Skip to content

Commit ea7e672

Browse files
committed
delete config_cls and model_cls, only maintain auto_map
1 parent c100996 commit ea7e672

File tree

2 files changed

+16
-13
lines changed

2 files changed

+16
-13
lines changed

iotdb-core/ainode/iotdb/ainode/core/model/model_info.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,6 @@ def __init__(
2828
category: ModelCategory,
2929
state: ModelStates,
3030
model_type: str = "",
31-
config_cls: str = "",
32-
model_cls: str = "",
3331
pipeline_cls: str = "",
3432
repo_id: str = "",
3533
auto_map: Optional[Dict] = None,
@@ -39,8 +37,6 @@ def __init__(
3937
self.model_type = model_type
4038
self.category = category
4139
self.state = state
42-
self.config_cls = config_cls
43-
self.model_cls = model_cls
4440
self.pipeline_cls = pipeline_cls
4541
self.repo_id = repo_id
4642
self.auto_map = auto_map # If exists, indicates it's a Transformers model
@@ -114,19 +110,23 @@ def __repr__(self):
114110
category=ModelCategory.BUILTIN,
115111
state=ModelStates.INACTIVE,
116112
model_type="timer",
117-
config_cls="configuration_timer.TimerConfig",
118-
model_cls="modeling_timer.TimerForPrediction",
119113
pipeline_cls="pipeline_timer.TimerPipeline",
120114
repo_id="thuml/timer-base-84m",
115+
auto_map={
116+
"AutoConfig": "configuration_timer.TimerConfig",
117+
"AutoModelForCausalLM": "modeling_timer.TimerForPrediction",
118+
},
121119
),
122120
"sundial": ModelInfo(
123121
model_id="sundial",
124122
category=ModelCategory.BUILTIN,
125123
state=ModelStates.INACTIVE,
126124
model_type="sundial",
127-
config_cls="configuration_sundial.SundialConfig",
128-
model_cls="modeling_sundial.SundialForPrediction",
129125
pipeline_cls="pipeline_sundial.SundialPipeline",
130126
repo_id="thuml/sundial-base-128m",
127+
auto_map={
128+
"AutoConfig": "configuration_sundial.SundialConfig",
129+
"AutoModelForCausalLM": "modeling_sundial.SundialForPrediction",
130+
},
131131
),
132132
}

iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -69,22 +69,25 @@ def load_model_from_transformers(model_info: ModelInfo, **model_kwargs):
6969
model_info.model_id,
7070
)
7171

72+
config_str = model_info.auto_map.get("AutoConfig", "")
73+
model_str = model_info.auto_map.get("AutoModelForCausalLM", "")
74+
7275
if model_info.category == ModelCategory.BUILTIN:
7376
module_name = (
7477
AINodeDescriptor().get_config().get_ain_models_builtin_dir()
7578
+ "."
7679
+ model_info.model_id
7780
)
78-
config_cls = import_class_from_path(module_name, model_info.config_cls)
79-
model_cls = import_class_from_path(module_name, model_info.model_cls)
80-
elif model_info.model_cls and model_info.config_cls:
81+
config_cls = import_class_from_path(module_name, config_str)
82+
model_cls = import_class_from_path(module_name, model_str)
83+
elif model_str and config_str:
8184
module_parent = str(Path(model_path).parent.absolute())
8285
with temporary_sys_path(module_parent):
8386
config_cls = import_class_from_path(
84-
model_info.model_id, model_info.config_cls
87+
model_info.model_id, config_str
8588
)
8689
model_cls = import_class_from_path(
87-
model_info.model_id, model_info.model_cls
90+
model_info.model_id, model_str
8891
)
8992
else:
9093
config_cls = AutoConfig.from_pretrained(model_path)

0 commit comments

Comments
 (0)