Skip to content

Commit 1e8dd7a

Browse files
committed
update error handler
1 parent 9868b41 commit 1e8dd7a

File tree

3 files changed

+51
-30
lines changed

3 files changed

+51
-30
lines changed

integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,10 @@ public void userDefinedModelManagementTestInTree() throws SQLException, Interrup
7575
callInferenceTest(
7676
statement, new FakeModelInfo("user_chronos", "custom_t5", "user_defined", "active"));
7777
dropUserDefinedModel(statement);
78+
errorTest(
79+
statement,
80+
"create model origin_chronos using uri \"file:///data/chronos2_origin\"",
81+
"1505: 't5' is already used by a Transformers config, pick another name.");
7882
}
7983
}
8084

@@ -86,6 +90,10 @@ public void userDefinedModelManagementTestInTable() throws SQLException, Interru
8690
forecastTableFunctionTest(
8791
statement, new FakeModelInfo("user_chronos", "custom_t5", "user_defined", "active"));
8892
dropUserDefinedModel(statement);
93+
errorTest(
94+
statement,
95+
"create model origin_chronos using uri \"file:///data/chronos2_origin\"",
96+
"1505: 't5' is already used by a Transformers config, pick another name.");
8997
}
9098
}
9199

iotdb-core/ainode/iotdb/ainode/core/manager/model_manager.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,13 @@ def register_model(
6161
return TRegisterModelResp(
6262
get_status(TSStatusCode.CREATE_MODEL_ERROR, str(e))
6363
)
64+
except Exception as e:
65+
# Catch-all for other exceptions (mainly from transformers implementation)
66+
return TRegisterModelResp(
67+
get_status(TSStatusCode.CREATE_MODEL_ERROR, str(e))
68+
)
6469

6570
def show_models(self, req: TShowModelsReq) -> TShowModelsResp:
66-
self._refresh()
6771
return self._model_storage.show_models(req)
6872

6973
def delete_model(self, req: TDeleteModelReq) -> TSStatus:

iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py

Lines changed: 38 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -227,17 +227,22 @@ def _process_user_defined_model_directory(self, model_dir: str, model_id: str):
227227
model_type = config.get("model_type", "")
228228
auto_map = config.get("auto_map", None)
229229
pipeline_cls = config.get("pipeline_cls", "")
230-
230+
model_info = ModelInfo(
231+
model_id=model_id,
232+
model_type=model_type,
233+
category=ModelCategory.USER_DEFINED,
234+
state=ModelStates.ACTIVE,
235+
pipeline_cls=pipeline_cls,
236+
auto_map=auto_map,
237+
transformers_registered=False, # Lazy registration
238+
)
239+
with self._lock_pool.get_lock(model_id).write_lock():
240+
self._models[ModelCategory.USER_DEFINED.value][model_id] = model_info
241+
if self.ensure_transformers_registered(model_id) is None:
242+
model_info.state = ModelStates.INACTIVE
243+
else:
244+
model_info.transformers_registered = True
231245
with self._lock_pool.get_lock(model_id).write_lock():
232-
model_info = ModelInfo(
233-
model_id=model_id,
234-
model_type=model_type,
235-
category=ModelCategory.USER_DEFINED,
236-
state=ModelStates.ACTIVE,
237-
pipeline_cls=pipeline_cls,
238-
auto_map=auto_map,
239-
transformers_registered=False, # Lazy registration
240-
)
241246
self._models[ModelCategory.USER_DEFINED.value][model_id] = model_info
242247

243248
# ==================== Registration Methods ====================
@@ -254,6 +259,7 @@ def register_model(self, model_id: str, uri: str):
254259
Raises:
255260
ModelExistedException: If the model_id already exists.
256261
InvalidModelUriException: If the URI format is invalid.
262+
Exception: For other errors during transformers model registration.
257263
"""
258264

259265
if self.is_model_registered(model_id):
@@ -291,25 +297,30 @@ def register_model(self, model_id: str, uri: str):
291297
)
292298
self._models[ModelCategory.USER_DEFINED.value][model_id] = model_info
293299

294-
if auto_map:
295-
# Transformers model: immediately register to Transformers autoloading mechanism
296-
success = self._register_transformers_model(model_info)
297-
if success:
298-
with self._lock_pool.get_lock(model_id).write_lock():
299-
model_info.transformers_registered = True
300-
else:
301-
with self._lock_pool.get_lock(model_id).write_lock():
300+
if auto_map:
301+
# Transformers model: immediately register to Transformers autoloading mechanism
302+
try:
303+
if self._register_transformers_model(model_info):
304+
model_info.transformers_registered = True
305+
except Exception as e:
302306
model_info.state = ModelStates.INACTIVE
303-
logger.error(f"Failed to register Transformers model {model_id}")
304-
else:
305-
# Other type models: only log
306-
self._register_other_model(model_info)
307+
logger.error(
308+
f"Failed to register Transformers model {model_id}, because {e}"
309+
)
310+
raise e
311+
else:
312+
# Other type models: only log
313+
self._register_other_model(model_info)
307314

308315
logger.info(f"Successfully registered model {model_id} from URI: {uri}")
309316

310-
def _register_transformers_model(self, model_info: ModelInfo):
317+
def _register_transformers_model(self, model_info: ModelInfo) -> bool:
311318
"""
312319
Register Transformers model to autoloading mechanism (internal method)
320+
Returns:
321+
True if registration is successful
322+
Raises:
323+
Exception: Transformers internal exception if registration fails
313324
"""
314325
auto_map = model_info.auto_map
315326
if not auto_map:
@@ -344,7 +355,7 @@ def _register_transformers_model(self, model_info: ModelInfo):
344355
logger.warning(
345356
f"Failed to register Transformers model {model_info.model_id}: {e}. Model may still work via auto_map, but ensure module path is correct."
346357
)
347-
return False
358+
raise e
348359

349360
def _register_other_model(self, model_info: ModelInfo):
350361
"""Register other type models (non-Transformers models)"""
@@ -354,10 +365,9 @@ def _register_other_model(self, model_info: ModelInfo):
354365

355366
def ensure_transformers_registered(self, model_id: str) -> ModelInfo | None:
356367
"""
357-
Ensure Transformers model is registered (called for lazy registration)
358-
This method uses locks to ensure thread safety. All check logic is within lock protection.
368+
Ensure Transformers model is registered.
359369
Returns:
360-
str: If None, registration failed, otherwise returns model path
370+
ModelInfo | None: None if registration failed, otherwise returns the corresponding ModelInfo
361371
"""
362372
# Use lock to protect entire check-execute process
363373
with self._lock_pool.get_lock(model_id).write_lock():
@@ -385,8 +395,7 @@ def ensure_transformers_registered(self, model_id: str) -> ModelInfo | None:
385395

386396
# Execute registration (under lock protection)
387397
try:
388-
success = self._register_transformers_model(model_info)
389-
if success:
398+
if self._register_transformers_model(model_info):
390399
model_info.transformers_registered = True
391400
logger.info(
392401
f"Model {model_id} successfully registered to Transformers"

0 commit comments

Comments
 (0)