Skip to content

Commit a572c1c

Browse files
authored
[AINode] Prevent auto_map gets covered and add model_list for AINodeConcurrentForecastIT (#16928)
1 parent 9386468 commit a572c1c

File tree

3 files changed

+16
-5
lines changed

3 files changed

+16
-5
lines changed

integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentForecastIT.java

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,9 @@
3636
import java.sql.Connection;
3737
import java.sql.SQLException;
3838
import java.sql.Statement;
39+
import java.util.Arrays;
40+
import java.util.List;
3941

40-
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.BUILTIN_LTSM_MAP;
4142
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkModelNotOnSpecifiedDevice;
4243
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkModelOnSpecifiedDevice;
4344
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.concurrentInference;
@@ -48,6 +49,11 @@ public class AINodeConcurrentForecastIT {
4849

4950
private static final Logger LOGGER = LoggerFactory.getLogger(AINodeConcurrentForecastIT.class);
5051

52+
private static final List<AINodeTestUtils.FakeModelInfo> MODEL_LIST =
53+
Arrays.asList(
54+
new AINodeTestUtils.FakeModelInfo("sundial", "sundial", "builtin", "active"),
55+
new AINodeTestUtils.FakeModelInfo("timer_xl", "timer", "builtin", "active"));
56+
5157
private static final String FORECAST_TABLE_FUNCTION_SQL_TEMPLATE =
5258
"SELECT * FROM FORECAST(model_id=>'%s', targets=>(SELECT time,s FROM root.AI) ORDER BY time, output_length=>%d)";
5359

@@ -78,7 +84,7 @@ private static void prepareDataForTableModel() throws SQLException {
7884

7985
@Test
8086
public void concurrentGPUForecastTest() throws SQLException, InterruptedException {
81-
for (AINodeTestUtils.FakeModelInfo modelInfo : BUILTIN_LTSM_MAP.values()) {
87+
for (AINodeTestUtils.FakeModelInfo modelInfo : MODEL_LIST) {
8288
concurrentGPUForecastTest(modelInfo);
8389
}
8490
}

iotdb-core/ainode/iotdb/ainode/core/model/model_info.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ def __repr__(self):
116116
"AutoConfig": "configuration_timer.TimerConfig",
117117
"AutoModelForCausalLM": "modeling_timer.TimerForPrediction",
118118
},
119+
_transformers_registered=True,
119120
),
120121
"sundial": ModelInfo(
121122
model_id="sundial",
@@ -128,5 +129,6 @@ def __repr__(self):
128129
"AutoConfig": "configuration_sundial.SundialConfig",
129130
"AutoModelForCausalLM": "modeling_sundial.SundialForPrediction",
130131
},
132+
_transformers_registered=True,
131133
),
132134
}

iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -196,9 +196,12 @@ def _callback_model_download_result(self, future, model_id: str):
196196
if os.path.exists(config_path):
197197
with open(config_path, "r", encoding="utf-8") as f:
198198
config = json.load(f)
199-
if model_info.model_type == "":
200-
model_info.model_type = config.get("model_type", "")
201-
model_info.auto_map = config.get("auto_map", None)
199+
model_info.model_type = config.get(
200+
"model_type", model_info.model_type
201+
)
202+
model_info.auto_map = config.get(
203+
"auto_map", model_info.auto_map
204+
)
202205
logger.info(
203206
f"Model {model_id} downloaded successfully and is ready to use."
204207
)

0 commit comments

Comments
 (0)