diff --git a/integration-test/setup_test.sh b/integration-test/setup_test.sh new file mode 100644 index 000000000000..e248908ef823 --- /dev/null +++ b/integration-test/setup_test.sh @@ -0,0 +1,74 @@ +#!/bin/bash +# setup_test_resources.sh + +# 创建目录结构 +mkdir -p src/test/resources/timerxl-example +mkdir -p src/test/resources/sundial-example +mkdir -p src/test/resources/legacy-example + +# 生成Python脚本并执行 +cat > generate_test_models.py << 'EOF' +import torch +import torch.nn as nn +from safetensors.torch import save_file + +def create_timerxl_weights(): + weights = { + "embeddings.weight": torch.randn(1000, 512), + "layers.0.attention.query.weight": torch.randn(512, 512), + "layers.0.attention.key.weight": torch.randn(512, 512), + "layers.0.attention.value.weight": torch.randn(512, 512), + "layers.0.attention.output.weight": torch.randn(512, 512), + "layers.0.mlp.gate_proj.weight": torch.randn(2048, 512), + "layers.0.mlp.up_proj.weight": torch.randn(2048, 512), + "layers.0.mlp.down_proj.weight": torch.randn(512, 2048), + "layers.0.input_layernorm.weight": torch.randn(512), + "prediction_head.weight": torch.randn(96, 512), + } + return weights + +def create_sundial_weights(): + weights = { + "patch_embed.proj.weight": torch.randn(512, 1, 16), + "blocks.0.norm1.weight": torch.randn(512), + "blocks.0.attn.qkv.weight": torch.randn(1536, 512), + "blocks.0.attn.proj.weight": torch.randn(512, 512), + "blocks.0.mlp.fc1.weight": torch.randn(2048, 512), + "blocks.0.mlp.fc2.weight": torch.randn(512, 2048), + "head.weight": torch.randn(96, 512), + } + return weights + +class SimpleLegacyModel(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(96, 96) + + def forward(self, x): + return self.linear(x) + +try: + # TimerXL + timerxl_weights = create_timerxl_weights() + save_file(timerxl_weights, "src/test/resources/timerxl-example/model.safetensors") + + # Sundial + sundial_weights = create_sundial_weights() + save_file(sundial_weights, "src/test/resources/sundial-example/model.safetensors") + + # Legacy + model = SimpleLegacyModel() + traced_model = torch.jit.trace(model, torch.randn(1, 96)) + torch.jit.save(traced_model, "src/test/resources/legacy-example/model.pt") + + print("All test model files generated successfully!") + +except ImportError as e: + print(f"Missing dependency: {e}") + print("Please install: pip install torch safetensors") +EOF + +python generate_test_models.py +rm generate_test_models.py + +echo "Test resources setup completed!" \ No newline at end of file diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/CreateModelIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/CreateModelIT.java new file mode 100644 index 000000000000..24ba9e48bdd2 --- /dev/null +++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/CreateModelIT.java @@ -0,0 +1,399 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.ainode.it; + +import org.apache.iotdb.it.env.EnvFactory; +import org.apache.iotdb.it.framework.IoTDBTestRunner; +import org.apache.iotdb.itbase.category.AIClusterIT; + +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; + +import java.io.File; +import java.sql.Connection; +import java.sql.ResultSet; +import java.sql.ResultSetMetaData; +import java.sql.SQLException; +import java.sql.Statement; + +import static org.apache.iotdb.db.it.utils.TestUtils.prepareData; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +@RunWith(IoTDBTestRunner.class) +@Category({AIClusterIT.class}) +public class CreateModelIT { + static final String TIMERXL_MODEL_PATH = + System.getProperty("user.dir") + + File.separator + + "src" + + File.separator + + "test" + + File.separator + + "resources" + + File.separator + + "timerxl-example"; + + static final String SUNDIAL_MODEL_PATH = + System.getProperty("user.dir") + + File.separator + + "src" + + File.separator + + "test" + + File.separator + + "resources" + + File.separator + + "sundial-example"; + + // Legacy format paths for backward compatibility testing + static final String LEGACY_MODEL_PATH = + System.getProperty("user.dir") + + File.separator + + "src" + + File.separator + + "test" + + File.separator + + "resources" + + File.separator + + "legacy-example"; + + static String[] setupSqls = + new String[] { + "set configuration \"trusted_uri_pattern\"='.*'", + "CREATE DATABASE root.iotdb.test", + "CREATE TIMESERIES root.iotdb.test.s0 WITH DATATYPE=FLOAT, ENCODING=RLE", + "CREATE TIMESERIES root.iotdb.test.s1 WITH DATATYPE=FLOAT, ENCODING=RLE", + }; + + static String[] timeSeriesData = new String[96]; + + static { + // Generate 96 time series data points + for (int i = 0; i < 96; i++) { + float value = (float) Math.sin(i * 0.1) + (float) Math.random() * 0.1f; + timeSeriesData[i] = + String.format( + "insert into root.iotdb.test(timestamp,s0,s1) values(%d,%.3f,%.3f)", + i + 1, value, value + 0.1f); + } + } + + @BeforeClass + public static void setUp() throws Exception { + EnvFactory.getEnv().initClusterEnvironment(1, 1); + + // Prepare basic setup and data + String[] allSqls = new String[setupSqls.length + timeSeriesData.length]; + System.arraycopy(setupSqls, 0, allSqls, 0, setupSqls.length); + System.arraycopy(timeSeriesData, 0, allSqls, setupSqls.length, timeSeriesData.length); + + prepareData(allSqls); + } + + @AfterClass + public static void tearDown() throws Exception { + EnvFactory.getEnv().cleanClusterEnvironment(); + } + + private static void checkHeader(ResultSetMetaData resultSetMetaData, String title) + throws SQLException { + String[] headers = title.split(","); + for (int i = 1; i <= resultSetMetaData.getColumnCount(); i++) { + assertEquals(headers[i - 1], resultSetMetaData.getColumnName(i)); + } + } + + @Test + public void timerXLModelOperationTest() { + String registerSql = "create model timerxl_test using uri \"" + TIMERXL_MODEL_PATH + "\""; + String showSql = "SHOW MODELS timerxl_test"; + String dropSql = "DROP MODEL timerxl_test"; + + try (Connection connection = EnvFactory.getEnv().getConnection(); + Statement statement = connection.createStatement()) { + + // Register model + statement.execute(registerSql); + + // Wait for model to load + boolean modelReady = false; + int maxRetries = 30; // Wait up to 30 seconds + + for (int i = 0; i < maxRetries; i++) { + try (ResultSet resultSet = statement.executeQuery(showSql)) { + ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); + checkHeader(resultSetMetaData, "ModelId,ModelType,State,Configs,Notes"); + + while (resultSet.next()) { + String modelName = resultSet.getString(1); + String modelType = resultSet.getString(2); + String status = resultSet.getString(3); + + assertEquals("timerxl_test", modelName); + assertEquals("USER_DEFINED", modelType); + + if (status.equals("ACTIVE")) { + modelReady = true; + break; + } else if (status.equals("LOADING")) { + Thread.sleep(1000); // Wait 1 second + break; + } else { + fail("Unexpected model status: " + status); + } + } + } + + if (modelReady) break; + } + + assertTrue("Model failed to become ACTIVE within timeout", modelReady); + + // Delete model + statement.execute(dropSql); + + // Verify model is deleted + try (ResultSet resultSet = statement.executeQuery(showSql)) { + int count = 0; + while (resultSet.next()) { + count++; + } + assertEquals(0, count); + } + + } catch (SQLException | InterruptedException e) { + fail(e.getMessage()); + } + } + + @Test + public void timerXLInferenceTest() { + String registerSql = "create model timerxl_inference using uri \"" + TIMERXL_MODEL_PATH + "\""; + String inferenceSql = + "CALL INFERENCE(timerxl_inference, \"select s0 from root.iotdb.test\", generateTime=true)"; + String dropSql = "DROP MODEL timerxl_inference"; + + try (Connection connection = EnvFactory.getEnv().getConnection(); + Statement statement = connection.createStatement()) { + + // Register model + statement.execute(registerSql); + + // Wait for model to be ready + Thread.sleep(5000); + + // Execute inference + try (ResultSet resultSet = statement.executeQuery(inferenceSql)) { + ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); + + // Check output columns + assertTrue( + "Should have Time column", resultSetMetaData.getColumnName(1).equals("Time")); + assertTrue( + "Should have at least one output column", resultSetMetaData.getColumnCount() >= 2); + + // Check output data + int rowCount = 0; + while (resultSet.next()) { + rowCount++; + // Verify time column is not null + assertNotNull("Time should not be null", resultSet.getTimestamp(1)); + + // Verify output value is numeric + float outputValue = resultSet.getFloat(2); + assertTrue("Output should be a valid number", !Float.isNaN(outputValue)); + } + + assertTrue("Should have output rows", rowCount > 0); + System.out.println("TimerXL inference generated " + rowCount + " predictions"); + } + + // Cleanup + statement.execute(dropSql); + + } catch (SQLException | InterruptedException e) { + fail(e.getMessage()); + } + } + + @Test + public void sundialModelOperationTest() { + String registerSql = "create model sundial_test using uri \"" + SUNDIAL_MODEL_PATH + "\""; + String showSql = "SHOW MODELS sundial_test"; + String dropSql = "DROP MODEL sundial_test"; + + try (Connection connection = EnvFactory.getEnv().getConnection(); + Statement statement = connection.createStatement()) { + + // Register model + statement.execute(registerSql); + + // Wait for model to load + boolean modelReady = false; + int maxRetries = 30; + + for (int i = 0; i < maxRetries; i++) { + try (ResultSet resultSet = statement.executeQuery(showSql)) { + while (resultSet.next()) { + String status = resultSet.getString(3); + if (status.equals("ACTIVE")) { + modelReady = true; + break; + } else if (status.equals("LOADING")) { + Thread.sleep(1000); + break; + } + } + } + if (modelReady) break; + } + + assertTrue("Sundial model failed to become ACTIVE", modelReady); + + // Delete model + statement.execute(dropSql); + + } catch (SQLException | InterruptedException e) { + fail(e.getMessage()); + } + } + + @Test + public void legacyModelCompatibilityTest() { + String registerSql = "create model legacy_test using uri \"" + LEGACY_MODEL_PATH + "\""; + String showSql = "SHOW MODELS legacy_test"; + String dropSql = "DROP MODEL legacy_test"; + + try (Connection connection = EnvFactory.getEnv().getConnection(); + Statement statement = connection.createStatement()) { + + // Register legacy model + statement.execute(registerSql); + + // Wait for model to load + boolean modelReady = false; + int maxRetries = 30; + + for (int i = 0; i < maxRetries; i++) { + try (ResultSet resultSet = statement.executeQuery(showSql)) { + while (resultSet.next()) { + String status = resultSet.getString(3); + if (status.equals("ACTIVE")) { + modelReady = true; + break; + } else if (status.equals("LOADING")) { + Thread.sleep(1000); + break; + } + } + } + if (modelReady) break; + } + + assertTrue("Legacy model failed to become ACTIVE", modelReady); + + // Delete model + statement.execute(dropSql); + + } catch (SQLException | InterruptedException e) { + fail(e.getMessage()); + } + } + + @Test + public void iotdbModelFormatDetectionTest() { + // Test that IoTDB format (config.json + safetensors) is detected and used properly + String registerSql = "create model format_test using uri \"" + TIMERXL_MODEL_PATH + "\""; + String showSql = "SHOW MODELS format_test"; + String dropSql = "DROP MODEL format_test"; + + try (Connection connection = EnvFactory.getEnv().getConnection(); + Statement statement = connection.createStatement()) { + + // Register model + statement.execute(registerSql); + + // Check that model loads successfully with IoTDB format + boolean modelReady = false; + int maxRetries = 30; + + for (int i = 0; i < maxRetries; i++) { + try (ResultSet resultSet = statement.executeQuery(showSql)) { + while (resultSet.next()) { + String modelName = resultSet.getString(1); + String status = resultSet.getString(3); + + assertEquals("format_test", modelName); + + if (status.equals("ACTIVE")) { + modelReady = true; + System.out.println("IoTDB format model loaded successfully"); + break; + } else if (status.equals("LOADING")) { + Thread.sleep(1000); + break; + } + } + } + if (modelReady) break; + } + + assertTrue("IoTDB format model failed to load", modelReady); + + // Cleanup + statement.execute(dropSql); + + } catch (SQLException | InterruptedException e) { + fail(e.getMessage()); + } + } + + @Test + public void iotdbModelErrorHandlingTest() { + // Test invalid URI + String invalidUriSql = "create model invalid_model using uri \"/nonexistent/path\""; + + try (Connection connection = EnvFactory.getEnv().getConnection(); + Statement statement = connection.createStatement()) { + + try { + statement.execute(invalidUriSql); + fail("Should throw exception for invalid URI"); + } catch (SQLException e) { + assertTrue( + "Should contain error message about invalid URI", + e.getMessage().contains("invalid") || e.getMessage().contains("not found")); + } + + } catch (SQLException e) { + fail("Unexpected error: " + e.getMessage()); + } + } + + private void assertNotNull(String message, Object value) { + if (value == null) { + fail(message); + } + } +} \ No newline at end of file diff --git a/integration-test/src/test/resources/legacy-example/config.yaml b/integration-test/src/test/resources/legacy-example/config.yaml new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/integration-test/src/test/resources/legacy-example/model.pt b/integration-test/src/test/resources/legacy-example/model.pt new file mode 100644 index 000000000000..2bf6e9b9fcee Binary files /dev/null and b/integration-test/src/test/resources/legacy-example/model.pt differ diff --git a/integration-test/src/test/resources/sundial-example/config.json b/integration-test/src/test/resources/sundial-example/config.json new file mode 100644 index 000000000000..d32ad2f07e85 --- /dev/null +++ b/integration-test/src/test/resources/sundial-example/config.json @@ -0,0 +1,22 @@ +{ + "model_type": "sundial", + "hidden_size": 512, + "num_hidden_layers": 6, + "num_attention_heads": 8, + "input_token_len": 96, + "output_token_lens": [96], + "hidden_act": "gelu", + "intermediate_size": 2048, + "max_position_embeddings": 10000, + "patch_size": 16, + "num_channels": 1, + "image_size": 96, + "torch_dtype": "float32", + "initializer_range": 0.02, + "use_cache": true, + "iotdb_version": "1.0.0", + "architecture": "SundialModel", + "task": "forecasting", + "flow_loss_depth": 3, + "num_sampling_steps": 50 +} \ No newline at end of file diff --git a/integration-test/src/test/resources/sundial-example/model.safetensors b/integration-test/src/test/resources/sundial-example/model.safetensors new file mode 100644 index 000000000000..759b03972bfb Binary files /dev/null and b/integration-test/src/test/resources/sundial-example/model.safetensors differ diff --git a/integration-test/src/test/resources/timerxl-example/config.json b/integration-test/src/test/resources/timerxl-example/config.json new file mode 100644 index 000000000000..18ffa906c65d --- /dev/null +++ b/integration-test/src/test/resources/timerxl-example/config.json @@ -0,0 +1,18 @@ +{ + "model_type": "timer", + "hidden_size": 512, + "num_hidden_layers": 8, + "num_attention_heads": 8, + "input_token_len": 96, + "output_token_lens": [96], + "hidden_act": "silu", + "intermediate_size": 2048, + "max_position_embeddings": 10000, + "rope_theta": 10000, + "torch_dtype": "float32", + "initializer_range": 0.02, + "use_cache": true, + "iotdb_version": "1.0.0", + "architecture": "TimerModel", + "task": "forecasting" +} \ No newline at end of file diff --git a/integration-test/src/test/resources/timerxl-example/model.safetensors b/integration-test/src/test/resources/timerxl-example/model.safetensors new file mode 100644 index 000000000000..c340ef4252ad Binary files /dev/null and b/integration-test/src/test/resources/timerxl-example/model.safetensors differ diff --git a/iotdb-core/ainode/ainode/core/client.py b/iotdb-core/ainode/ainode/core/client.py index 15385928a84a..dcb10976a9c2 100644 --- a/iotdb-core/ainode/ainode/core/client.py +++ b/iotdb-core/ainode/ainode/core/client.py @@ -246,6 +246,7 @@ def get_ainode_configuration(self, node_id: int) -> map: self._wait_and_reconnect() raise TException(self._MSG_RECONNECTION_FAIL) + # TODO: create model的新ConfigNode回传函数 def update_model_info( self, model_id: str, @@ -255,8 +256,14 @@ def update_model_info( input_length=0, output_length=0, ) -> None: + """ + 更新模型信息到ConfigNode,增强日志记录 + """ if ainode_id is None: ainode_id = [] + + logger.info(f"Updating model info: {model_id}, status: {model_status}") + for _ in range(0, self._RETRY_NUM): try: req = TUpdateModelInfoReq(model_id, model_status, attribute) @@ -269,6 +276,7 @@ def update_model_info( verify_success( status, "An error occurs when calling update model info" ) + logger.info(f"Successfully updated model info for {model_id}") return status except TTransport.TException: logger.warning( @@ -277,4 +285,4 @@ def update_model_info( ) self._config_leader = None self._wait_and_reconnect() - raise TException(self._MSG_RECONNECTION_FAIL) + raise TException(self._MSG_RECONNECTION_FAIL) \ No newline at end of file diff --git a/iotdb-core/ainode/ainode/core/config.py b/iotdb-core/ainode/ainode/core/config.py index 62de76fcbb83..0ad86c4a758d 100644 --- a/iotdb-core/ainode/ainode/core/config.py +++ b/iotdb-core/ainode/ainode/core/config.py @@ -72,6 +72,11 @@ def __init__(self): self._version_info = AINODE_VERSION_INFO self._build_info = AINODE_BUILD_INFO + + # create model新增:网络加载约束 + self._support_iotdb_models = True + self._iotdb_model_timeout = 300 # 模型加载超时时间(秒) + self._auto_model_format_detection = True def get_cluster_name(self) -> str: return self._cluster_name @@ -145,6 +150,15 @@ def set_ain_target_config_node_list(self, ain_target_config_node_list: str) -> N self._ain_target_config_node_list = parse_endpoint_url( ain_target_config_node_list ) + + def get_support_iotdb_models(self) -> bool: + return self._support_iotdb_models + + def get_iotdb_model_timeout(self) -> int: + return self._iotdb_model_timeout + + def get_auto_model_format_detection(self) -> bool: + return self._auto_model_format_detection @singleton diff --git a/iotdb-core/ainode/ainode/core/constant.py b/iotdb-core/ainode/ainode/core/constant.py index 24d13a12ab8a..dcf34ac87f63 100644 --- a/iotdb-core/ainode/ainode/core/constant.py +++ b/iotdb-core/ainode/ainode/core/constant.py @@ -62,6 +62,19 @@ STD_LEVEL = logging.INFO +# TODO: 检查模型文件名 +IOTDB_CONFIG_FILES = ["config.json", "configuration.json"] +IOTDB_WEIGHT_FILES = ["model.safetensors", "pytorch_model.safetensors", "model.pt", "pytorch_model.pt"] + +# 辅助类:模型状态常量 +class ModelStatus(Enum): + LOADING = 0 + ACTIVE = 1 + INACTIVE = 2 + ERROR = 3 + + def get_status_code(self) -> int: + return self.value class TSStatusCode(Enum): SUCCESS_STATUS = 200 @@ -70,6 +83,9 @@ class TSStatusCode(Enum): INVALID_URI_ERROR = 1511 INVALID_INFERENCE_CONFIG = 1512 INFERENCE_INTERNAL_ERROR = 1520 + # create model专用的错误码 + MODEL_LOADING_ERROR = 1521 # 新增 + MODEL_FORMAT_ERROR = 1522 # 新增 def get_status_code(self) -> int: return self.value diff --git a/iotdb-core/ainode/ainode/core/exception.py b/iotdb-core/ainode/ainode/core/exception.py index 977b10cfa04f..1aa66c3c518a 100644 --- a/iotdb-core/ainode/ainode/core/exception.py +++ b/iotdb-core/ainode/ainode/core/exception.py @@ -133,6 +133,27 @@ def __init__(self, model_name: str, attribute_name: str): self.message = "Attribute {0} is not supported in model {1}".format( attribute_name, model_name ) + +# create model 补充的异常类 +class ModelLoadingError(_BaseError): + def __init__(self, model_id: str, error_msg: str): + self.message = f"Failed to load model {model_id}: {error_msg}" + + +class ModelFormatError(_BaseError): + def __init__(self, model_path: str, expected_format: str): + self.message = f"Invalid model format at {model_path}, expected {expected_format}" + + +class IoTDBModelError(_BaseError): + def __init__(self, model_type: str, error_msg: str): + self.message = f"IoTDB model error for {model_type}: {error_msg}" + + +class UnsupportedModelTypeError(_BaseError): + def __init__(self, model_type: str): + self.message = f"Unsupported model type: {model_type}. Supported types: timer, sundial" + # This is used to extract the key message in RuntimeError instead of the traceback message diff --git a/iotdb-core/ainode/ainode/core/handler.py b/iotdb-core/ainode/ainode/core/handler.py index 456bc97269a3..267fca2e4713 100644 --- a/iotdb-core/ainode/ainode/core/handler.py +++ b/iotdb-core/ainode/ainode/core/handler.py @@ -32,6 +32,7 @@ TTrainingReq, ) from ainode.thrift.common.ttypes import TSStatus +from ainode.core.constant import TSStatusCode, logging class AINodeRPCServiceHandler(IAINodeRPCService.Iface): @@ -40,6 +41,7 @@ def __init__(self): self._inference_manager = InferenceManager(model_manager=self._model_manager) def registerModel(self, req: TRegisterModelReq) -> TRegisterModelResp: + # todo: 可能需要增强 return self._model_manager.register_model(req) def deleteModel(self, req: TDeleteModelReq) -> TSStatus: @@ -56,3 +58,24 @@ def getAIHeartbeat(self, req: TAIHeartbeatReq) -> TAIHeartbeatResp: def createTrainingTask(self, req: TTrainingReq) -> TSStatus: pass + + # TODO:处理模型状态回传 + def updateModelStatus(self, model_id: str, status: int, message: str = "") -> TSStatus: + """更新模型状态到ConfigNode""" + try: + from ainode.core.client import ClientManager + from ainode.core.config import AINodeDescriptor + + ClientManager().borrow_config_node_client().update_model_info( + model_id=model_id, + model_status=status, + attribute=message, + ainode_id=[AINodeDescriptor().get_config().get_ainode_id()] + ) + return TSStatus(code=TSStatusCode.SUCCESS_STATUS.get_status_code()) + except Exception as e: + logging.error(f"Failed to update model status: {e}") + return TSStatus( + code=TSStatusCode.AINODE_INTERNAL_ERROR.get_status_code(), + message=str(e) + ) \ No newline at end of file diff --git a/iotdb-core/ainode/ainode/core/manager/inference_manager.py b/iotdb-core/ainode/ainode/core/manager/inference_manager.py index eb8becd0f177..fc2497e3556c 100644 --- a/iotdb-core/ainode/ainode/core/manager/inference_manager.py +++ b/iotdb-core/ainode/ainode/core/manager/inference_manager.py @@ -52,30 +52,51 @@ def infer(self, full_data, **kwargs): # [IoTDB] full data deserialized from iotdb is composed of [timestampList, valueList, length], # we only get valueList currently. +# TODO: 模型推理 class TimerXLStrategy(InferenceStrategy): - def infer(self, full_data, predict_length=96, **_): + def infer(self, full_data, predict_length=96, **kwargs): data = full_data[1][0] if data.dtype.byteorder not in ("=", "|"): data = data.byteswap().newbyteorder() seqs = torch.tensor(data).unsqueeze(0).float() - # TODO: unify model inference input - output = self.model.generate(seqs, max_new_tokens=predict_length, revin=True) - df = pd.DataFrame(output[0]) - return convert_to_binary(df) - + + # 支持IoTDB模型的推理参数 + revin = kwargs.get("revin", True) + max_tokens = kwargs.get("max_new_tokens", predict_length) + + logger.info(f"TimerXL inference: input_shape={seqs.shape}, predict_length={max_tokens}") + + try: + output = self.model.generate(seqs, max_new_tokens=max_tokens, revin=revin) + df = pd.DataFrame(output[0]) + return convert_to_binary(df) + except Exception as e: + logger.error(f"TimerXL inference failed: {e}") + raise InferenceModelInternalError(f"TimerXL inference error: {str(e)}") class SundialStrategy(InferenceStrategy): - def infer(self, full_data, predict_length=96, **_): + def infer(self, full_data, predict_length=96, **kwargs): data = full_data[1][0] if data.dtype.byteorder not in ("=", "|"): data = data.byteswap().newbyteorder() seqs = torch.tensor(data).unsqueeze(0).float() - # TODO: unify model inference input - output = self.model.generate( - seqs, max_new_tokens=predict_length, num_samples=10, revin=True - ) - df = pd.DataFrame(output[0].mean(dim=0)) - return convert_to_binary(df) + + # 支持IoTDB模型的推理参数 + revin = kwargs.get("revin", True) + max_tokens = kwargs.get("max_new_tokens", predict_length) + num_samples = kwargs.get("num_samples", 10) + + logger.info(f"Sundial inference: input_shape={seqs.shape}, predict_length={max_tokens}, num_samples={num_samples}") + + try: + output = self.model.generate( + seqs, max_new_tokens=max_tokens, num_samples=num_samples, revin=revin + ) + df = pd.DataFrame(output[0].mean(dim=0)) + return convert_to_binary(df) + except Exception as e: + logger.error(f"Sundial inference failed: {e}") + raise InferenceModelInternalError(f"Sundial inference error: {str(e)}") class BuiltInStrategy(InferenceStrategy): @@ -116,16 +137,48 @@ def infer(self, full_data, window_interval=None, window_step=None, **kwargs): return [convert_to_binary(df) for df in results] +# def _get_strategy(model_id, model): +# if model_id == "_timerxl": +# return TimerXLStrategy(model) +# if model_id == "_sundial": +# return SundialStrategy(model) +# if model_id.startswith("_"): +# return BuiltInStrategy(model) +# return RegisteredStrategy(model) + +# 在现有InferenceManager基础上修改策略获取逻辑 + +# create model更新:策略选择加强 def _get_strategy(model_id, model): - if model_id == "_timerxl": + # 支持IoTDB模型的动态策略选择 + if model_id == "_timerxl" or model_id.startswith("timerxl") or "_timer" in model_id.lower(): return TimerXLStrategy(model) - if model_id == "_sundial": + if model_id == "_sundial" or model_id.startswith("sundial") or "_sundial" in model_id.lower(): return SundialStrategy(model) if model_id.startswith("_"): return BuiltInStrategy(model) + + # 对于用户定义的模型,尝试从模型属性中判断类型 + try: + # TODO:这里可以添加逻辑来检查模型的配置文件,以确定应该使用哪种策略 + model_manager = ModelManager() + model_info = model_manager.get_model_status(model_id) + + # 从模型属性中提取模型类型 + if "timer" in model_id.lower() or "timerxl" in str(model_info).lower(): + logger.info(f"Using TimerXL strategy for model {model_id}") + return TimerXLStrategy(model) + elif "sundial" in model_id.lower() or "sundial" in str(model_info).lower(): + logger.info(f"Using Sundial strategy for model {model_id}") + return SundialStrategy(model) + + except Exception as e: + logger.warning(f"Failed to determine model type for {model_id}, using default strategy: {e}") + return RegisteredStrategy(model) + class InferenceManager: def __init__(self, model_manager: ModelManager): diff --git a/iotdb-core/ainode/ainode/core/manager/model_manager.py b/iotdb-core/ainode/ainode/core/manager/model_manager.py index 95fdda1456b1..fbde9b34c28d 100644 --- a/iotdb-core/ainode/ainode/core/manager/model_manager.py +++ b/iotdb-core/ainode/ainode/core/manager/model_manager.py @@ -36,36 +36,115 @@ ) from ainode.thrift.common.ttypes import TSStatus +# create model新增:json格式解析 +import json +import threading +import time +import os +from ainode.core.client import ClientManager +from ainode.core.config import AINodeDescriptor +from ainode.core.util.model_utils import ( + detect_model_format, + validate_iotdb_model_config, + get_model_info_from_config, + validate_model_name +) + logger = Logger() class ModelManager: def __init__(self): self.model_storage = ModelStorage() + self._model_status_cache = {} # 缓存模型状态 + self._status_lock = threading.Lock() def register_model(self, req: TRegisterModelReq) -> TRegisterModelResp: + + # TODO:验证模型名称是否受到支持 + if not validate_model_name(req.modelId): + return TRegisterModelResp( + get_status(TSStatusCode.INVALID_URI_ERROR, "Invalid model name") + ) + + # 更新模型状态为加载中 + self._update_model_status(req.modelId, "LOADING", "Model registration started") + logger.info(f"register model {req.modelId} from {req.uri}") try: + # 检测模型格式 + from ainode.core.util.model_utils import parse_model_uri + is_network, parsed_uri = parse_model_uri(req.uri) + + # 使用现有的模型存储注册机制 configs, attributes = self.model_storage.register_model( req.modelId, req.uri ) + + # 检查是否为IoTDB格式并验证 + try: + model_dir = self.model_storage._get_model_directory(req.modelId) + format_type, config_file, weight_file = detect_model_format(model_dir) + + if format_type == "iotdb": + config_path = os.path.join(model_dir, config_file) + model_info = get_model_info_from_config(config_path) + + # 验证IoTDB配置 + with open(config_path, 'r', encoding='utf-8') as f: + config_dict = json.load(f) + + if not validate_iotdb_model_config(config_dict): + self._update_model_status(req.modelId, "ERROR", "Invalid IoTDB model configuration") + self.model_storage.delete_model(req.modelId) + return TRegisterModelResp( + get_status(TSStatusCode.INVALID_INFERENCE_CONFIG, "Invalid IoTDB model configuration") + ) + + # 更新attributes以包含IoTDB模型信息 + import ast + try: + attr_dict = ast.literal_eval(attributes) if attributes else {} + except: + attr_dict = {} + + attr_dict.update({ + "format": "iotdb", + "model_type": model_info["model_type"], + "architecture": model_info.get("architecture", ""), + "task": model_info.get("task", "forecasting") + }) + attributes = str(attr_dict) + + logger.info(f"Successfully registered IoTDB model: {req.modelId}, type: {model_info['model_type']}") + + except Exception as e: + logger.warning(f"Format detection failed, treating as legacy model: {e}") + + # 更新模型状态为活跃 + self._update_model_status(req.modelId, "ACTIVE", "Model registration completed") + return TRegisterModelResp( get_status(TSStatusCode.SUCCESS_STATUS), configs, attributes ) + except InvalidUriError as e: logger.warning(e) + self._update_model_status(req.modelId, "ERROR", str(e)) self.model_storage.delete_model(req.modelId) return TRegisterModelResp( get_status(TSStatusCode.INVALID_URI_ERROR, e.message) ) except BadConfigValueError as e: logger.warning(e) + self._update_model_status(req.modelId, "ERROR", str(e)) self.model_storage.delete_model(req.modelId) return TRegisterModelResp( get_status(TSStatusCode.INVALID_INFERENCE_CONFIG, e.message) ) except YAMLError as e: logger.warning(e) + self._update_model_status(req.modelId, "ERROR", "YAML parsing error") self.model_storage.delete_model(req.modelId) if hasattr(e, "problem_mark"): mark = e.problem_mark @@ -84,18 +163,66 @@ def register_model(self, req: TRegisterModelReq) -> TRegisterModelResp: ) except Exception as e: logger.warning(e) + self._update_model_status(req.modelId, "ERROR", str(e)) self.model_storage.delete_model(req.modelId) - return TRegisterModelResp(get_status(TSStatusCode.AINODE_INTERNAL_ERROR)) + return TRegisterModelResp(get_status(TSStatusCode.AINODE_INTERNAL_ERROR, str(e))) def delete_model(self, req: TDeleteModelReq) -> TSStatus: logger.info(f"delete model {req.modelId}") try: + # 更新模型状态为非活跃 + self._update_model_status(req.modelId, "INACTIVE", "Model deletion started") + self.model_storage.delete_model(req.modelId) + + # 从状态缓存中移除 + with self._status_lock: + self._model_status_cache.pop(req.modelId, None) + return get_status(TSStatusCode.SUCCESS_STATUS) except Exception as e: logger.warning(e) + self._update_model_status(req.modelId, "ERROR", str(e)) return get_status(TSStatusCode.AINODE_INTERNAL_ERROR, str(e)) + def _update_model_status(self, model_id: str, status: str, message: str = ""): + """更新模型状态并通知ConfigNode""" + try: + with self._status_lock: + self._model_status_cache[model_id] = { + "status": status, + "message": message, + "timestamp": time.time() + } + + # 映射状态到数字码 + status_code_map = { + "LOADING": 0, + "ACTIVE": 1, + "INACTIVE": 2, + "ERROR": 3 + } + + status_code = status_code_map.get(status, 3) + + # 通知ConfigNode + ClientManager().borrow_config_node_client().update_model_info( + model_id=model_id, + model_status=status_code, + attribute=message, + ainode_id=[AINodeDescriptor().get_config().get_ainode_id()] + ) + + logger.info(f"Model {model_id} status updated to {status}: {message}") + + except Exception as e: + logger.error(f"Failed to update model status for {model_id}: {e}") + + def get_model_status(self, model_id: str) -> dict: + """获取模型状态""" + with self._status_lock: + return self._model_status_cache.get(model_id, {"status": "UNKNOWN", "message": ""}) + def load_model(self, model_id: str, acceleration: bool = False) -> Callable: logger.info(f"load model {model_id}") return self.model_storage.load_model(model_id, acceleration) diff --git a/iotdb-core/ainode/ainode/core/model/config_parser.py b/iotdb-core/ainode/ainode/core/model/config_parser.py new file mode 100644 index 000000000000..78a64a48b33a --- /dev/null +++ b/iotdb-core/ainode/ainode/core/model/config_parser.py @@ -0,0 +1,133 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import json +from pathlib import Path +from typing import Any, Dict, Union + +import yaml + +from ainode.core.log import Logger + +logger = Logger() + + +def parse_config_file(config_path: Union[str, Path]) -> Dict[str, Any]: + """ + 解析配置文件,支持JSON和YAML格式 + + Args: + config_path: 配置文件路径 + + Returns: + 配置字典 + """ + config_path = Path(config_path) + + if not config_path.exists(): + raise FileNotFoundError(f"配置文件不存在: {config_path}") + + suffix = config_path.suffix.lower() + + try: + with open(config_path, "r", encoding="utf-8") as f: + if suffix == ".json": + return json.load(f) + elif suffix in [".yaml", ".yml"]: + return yaml.safe_load(f) + else: + # 尝试JSON解析 + content = f.read() + try: + return json.loads(content) + except json.JSONDecodeError: + # 尝试YAML解析 + return yaml.safe_load(content) + except Exception as e: + logger.error(f"解析配置文件失败: {config_path}, 错误: {e}") + raise + + +def convert_iotdb_config_to_ainode_format( + iotdb_config: Dict[str, Any] +) -> Dict[str, Any]: + """ + 将 IoTDB 配置转换为 AINode 格式 + + Args: + iotdb_config: IoTDB 配置字典 + + Returns: + AINode 格式的配置字典 + """ + # 提取基础信息 + model_type = iotdb_config.get("model_type", "unknown") + input_length = iotdb_config.get("input_token_len", 96) + output_length = iotdb_config.get("output_token_lens", [96])[0] + + # 转换为AINode格式 + ainode_config = { + "configs": { + "input_shape": [input_length, 1], # IoTDB时序模型输入为单维 + "output_shape": [output_length, 1], # IoTDB时序模型输出为单维 + "input_type": ["float32"], + "output_type": ["float32"], + }, + "attributes": { + "model_type": model_type, + "iotdb_model": True, + "original_config": iotdb_config, + }, + } + + logger.debug(f"转换 IoTDB 配置: {model_type} -> AINode格式") + return ainode_config + + +def apply_config_patches(config: Dict[str, Any], model_type: str) -> Dict[str, Any]: + """ + 应用配置补丁,支持版本兼容 + + Args: + config: 原始配置 + model_type: 模型类型 + + Returns: + 应用补丁后的配置 + """ + patches = { + "timer": { + # TimerXL特定补丁 + "n_embd": "hidden_size", + "n_layer": "num_hidden_layers", + "n_head": "num_attention_heads", + "seq_len": "input_token_len", + }, + "sundial": { + # Sundial特定补丁 + "diff_steps": "num_sampling_steps", + "flow_depth": "flow_loss_depth", + }, + } + + if model_type in patches: + for old_key, new_key in patches[model_type].items(): + if old_key in config and new_key not in config: + config[new_key] = config.pop(old_key) + logger.debug(f"应用补丁: {old_key} -> {new_key}") + + return config \ No newline at end of file diff --git a/iotdb-core/ainode/ainode/core/model/model_factory.py b/iotdb-core/ainode/ainode/core/model/model_factory.py index 702826f9cba5..800a1f5aaea2 100644 --- a/iotdb-core/ainode/ainode/core/model/model_factory.py +++ b/iotdb-core/ainode/ainode/core/model/model_factory.py @@ -18,6 +18,7 @@ import os import shutil from urllib.parse import urljoin, urlparse +from pathlib import Path import yaml from requests import Session @@ -34,12 +35,17 @@ from ainode.core.log import Logger from ainode.core.util.serde import get_data_type_byte_from_str from ainode.thrift.ainode.ttypes import TConfigs +from ainode.core.model.config_parser import parse_config_file, convert_iotdb_config_to_ainode_format HTTP_PREFIX = "http://" HTTPS_PREFIX = "https://" logger = Logger() +# IoTDB 模型相关文件名 +IOTDB_CONFIG_FILES = ["config.json", "configuration.json"] +IOTDB_WEIGHT_FILES = ["model.safetensors", "pytorch_model.safetensors", "model.pt", "pytorch_model.pt"] + def _parse_uri(uri): """ @@ -90,6 +96,39 @@ def _download_file(url: str, storage_path: str) -> None: logger.debug(f"download file from {url} to {storage_path} success") +def _detect_model_format(base_path: str) -> tuple: + """ + 检测模型格式:legacy (model.pt + config.yaml) 或 IoTDB (config.json + safetensors) + + Args: + base_path: 模型目录路径 + + Returns: + (format_type, config_file, weight_file): 格式类型和对应的文件名 + """ + base_path = Path(base_path) + + # 检查 IoTDB 格式 + for config_file in IOTDB_CONFIG_FILES: + config_path = base_path / config_file + if config_path.exists(): + # 查找权重文件 + for weight_file in IOTDB_WEIGHT_FILES: + weight_path = base_path / weight_file + if weight_path.exists(): + logger.debug(f"检测到 IoTDB 格式: {config_file} + {weight_file}") + return "iotdb", config_file, weight_file + + # 检查 legacy 格式 + legacy_config = base_path / DEFAULT_CONFIG_FILE_NAME + legacy_model = base_path / DEFAULT_MODEL_FILE_NAME + if legacy_config.exists() and legacy_model.exists(): + logger.debug(f"检测到 legacy 格式: {DEFAULT_CONFIG_FILE_NAME} + {DEFAULT_MODEL_FILE_NAME}") + return "legacy", DEFAULT_CONFIG_FILE_NAME, DEFAULT_MODEL_FILE_NAME + + return None, None, None + + def _register_model_from_network( uri: str, model_storage_path: str, config_storage_path: str ) -> [TConfigs, str]: @@ -105,19 +144,57 @@ def _register_model_from_network( """ # concat uri to get complete url uri = uri if uri.endswith("/") else uri + "/" - target_model_path = urljoin(uri, DEFAULT_MODEL_FILE_NAME) - target_config_path = urljoin(uri, DEFAULT_CONFIG_FILE_NAME) - - # download config file - _download_file(target_config_path, config_storage_path) + + # 首先尝试检测 IoTDB 格式 + iotdb_detected = False + configs, attributes = None, None + + for config_file in IOTDB_CONFIG_FILES: + try: + target_config_path = urljoin(uri, config_file) + _download_file(target_config_path, config_storage_path) + + # 解析 IoTDB 配置 + iotdb_config = parse_config_file(config_storage_path) + ainode_config = convert_iotdb_config_to_ainode_format(iotdb_config) + configs, attributes = _parse_inference_config(ainode_config) + + # 查找对应的权重文件 + for weight_file in IOTDB_WEIGHT_FILES: + try: + target_model_path = urljoin(uri, weight_file) + _download_file(target_model_path, model_storage_path) + iotdb_detected = True + logger.info(f"成功下载 IoTDB 模型: {config_file} + {weight_file}") + break + except Exception as e: + logger.debug(f"未找到权重文件 {weight_file}: {e}") + continue + + if iotdb_detected: + break + + except Exception as e: + logger.debug(f"未找到配置文件 {config_file}: {e}") + continue + + # 如果未检测到 IoTDB 格式,尝试 legacy 格式 + if not iotdb_detected: + logger.debug("未检测到 IoTDB 格式,尝试 legacy 格式") + target_model_path = urljoin(uri, DEFAULT_MODEL_FILE_NAME) + target_config_path = urljoin(uri, DEFAULT_CONFIG_FILE_NAME) + + # download config file + _download_file(target_config_path, config_storage_path) - # read and parse config dict from config.yaml - with open(config_storage_path, "r", encoding="utf-8") as file: - config_dict = yaml.safe_load(file) - configs, attributes = _parse_inference_config(config_dict) + # read and parse config dict from config.yaml + with open(config_storage_path, "r", encoding="utf-8") as file: + config_dict = yaml.safe_load(file) + configs, attributes = _parse_inference_config(config_dict) - # if config.yaml is correct, download model file - _download_file(target_model_path, model_storage_path) + # if config.yaml is correct, download model file + _download_file(target_model_path, model_storage_path) + return configs, attributes @@ -134,39 +211,37 @@ def _register_model_from_local( configs: TConfigs attributes: str """ - # concat uri to get complete path - target_model_path = os.path.join(uri, DEFAULT_MODEL_FILE_NAME) - target_config_path = os.path.join(uri, DEFAULT_CONFIG_FILE_NAME) - - # check if file exist - exist_model_file = os.path.exists(target_model_path) - exist_config_file = os.path.exists(target_config_path) - - configs = None - attributes = None - if exist_model_file and exist_config_file: - # copy config.yaml - logger.debug(f"copy file from {target_config_path} to {config_storage_path}") - shutil.copy(target_config_path, config_storage_path) - logger.debug( - f"copy file from {target_config_path} to {config_storage_path} success" - ) - - # read and parse config dict from config.yaml + # 检测模型格式 + format_type, config_file, weight_file = _detect_model_format(uri) + + if format_type is None: + raise InvalidUriError(f"未找到有效的模型文件在路径: {uri}") + + target_config_path = os.path.join(uri, config_file) + target_model_path = os.path.join(uri, weight_file) + + # 复制配置文件 + logger.debug(f"copy file from {target_config_path} to {config_storage_path}") + shutil.copy(target_config_path, config_storage_path) + logger.debug(f"copy file from {target_config_path} to {config_storage_path} success") + + # 解析配置文件 + if format_type == "iotdb": + # IoTDB 格式 + iotdb_config = parse_config_file(config_storage_path) + ainode_config = convert_iotdb_config_to_ainode_format(iotdb_config) + configs, attributes = _parse_inference_config(ainode_config) + else: + # legacy 格式 with open(config_storage_path, "r", encoding="utf-8") as file: config_dict = yaml.safe_load(file) configs, attributes = _parse_inference_config(config_dict) - - # if config.yaml is correct, copy model file - logger.debug(f"copy file from {target_model_path} to {model_storage_path}") - shutil.copy(target_model_path, model_storage_path) - logger.debug( - f"copy file from {target_model_path} to {model_storage_path} success" - ) - - elif not exist_model_file or not exist_config_file: - raise InvalidUriError(uri) - + + # 复制模型文件 + logger.debug(f"copy file from {target_model_path} to {model_storage_path}") + shutil.copy(target_model_path, model_storage_path) + logger.debug(f"copy file from {target_model_path} to {model_storage_path} success") + return configs, attributes @@ -288,4 +363,4 @@ def fetch_model_by_uri(uri: str, model_storage_path: str, config_storage_path: s uri, model_storage_path, config_storage_path ) else: - return _register_model_from_local(uri, model_storage_path, config_storage_path) + return _register_model_from_local(uri, model_storage_path, config_storage_path) \ No newline at end of file diff --git a/iotdb-core/ainode/ainode/core/model/model_loader/__init__.py b/iotdb-core/ainode/ainode/core/model/model_loader/__init__.py new file mode 100644 index 000000000000..16283477eb84 --- /dev/null +++ b/iotdb-core/ainode/ainode/core/model/model_loader/__init__.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from .base_model import ThuTLBaseModel +from ..sundial import SundialConfig, modeling_sundial +from ..timerxl import TimerXLConfig, modeling_timer + +__all__ = ["ThuTLBaseModel", "modeling_timer", "TimerXLConfig", "modeling_sundial", "SundialConfig"] diff --git a/iotdb-core/ainode/ainode/core/model/model_loader/base_model.py b/iotdb-core/ainode/ainode/core/model/model_loader/base_model.py new file mode 100644 index 000000000000..9068700e42b0 --- /dev/null +++ b/iotdb-core/ainode/ainode/core/model/model_loader/base_model.py @@ -0,0 +1,120 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +thuTL 基础模型抽象,适配 AINode 架构 +""" + +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Any, Dict, Iterator, Tuple, Union + +import torch +import torch.nn as nn +from model_config import ModelConfig + + +class BaseModel(ABC, nn.Module): + """thuTL 模型基类,适配 AINode 架构""" + + def __init__(self, config: "ModelConfig"): + super().__init__() + self.config = config + self.build_layers() + + @abstractmethod + def build_layers(self): + """构建模型层,子类必须实现""" + pass + + @abstractmethod + def forward(self, input_ids: torch.Tensor, **kwargs) -> torch.Tensor: + """前向传播,子类必须实现""" + pass + + def preprocess(self, data: Union[torch.Tensor, Any]) -> torch.Tensor: + """预处理输入数据""" + if not isinstance(data, torch.Tensor): + data = torch.tensor(data, dtype=torch.float32) + + if data.dim() == 1: + data = data.unsqueeze(0) + + # 截断或填充到指定长度 + target_len = self.config.input_token_len + if data.size(-1) > target_len: + data = data[..., :target_len] + elif data.size(-1) < target_len: + pad_size = target_len - data.size(-1) + data = torch.nn.functional.pad(data, (0, pad_size)) + + return data + + def postprocess(self, logits: torch.Tensor) -> torch.Tensor: + """后处理输出""" + target_len = self.config.output_token_lens[0] + return logits[..., :target_len] + + @classmethod + def from_pretrained( + cls, config_path: Union[str, Path], weights_path: Union[str, Path], **kwargs + ) -> "BaseModel": + """从预训练模型加载""" + from .model_config import ModelConfig + from .weight_loader import load_weights + + # 加载配置 + config = ModelConfig.from_json(config_path) + + # 创建模型实例 + model = cls(config, **kwargs) + + # 加载权重 + weights = load_weights(weights_path) + missing_keys, unexpected_keys = model.load_state_dict(weights, strict=False) + + if missing_keys: + print(f"Missing keys: {missing_keys}") + if unexpected_keys: + print(f"Unexpected keys: {unexpected_keys}") + + return model + + +# 模型注册机制 +_MODEL_REGISTRY: Dict[str, type] = {} + + +def register_model(name: str): + """模型注册装饰器""" + + def wrapper(cls): + _MODEL_REGISTRY[name] = cls + return cls + + return wrapper + + +def get_model_class(name: str) -> type: + """获取已注册的模型类""" + if name not in _MODEL_REGISTRY: + raise ValueError(f"Unknown model type: {name}") + return _MODEL_REGISTRY[name] + + +def list_available_models() -> list: + """列出所有可用模型""" + return list(_MODEL_REGISTRY.keys()) diff --git a/iotdb-core/ainode/ainode/core/model/model_loader/model_config.py b/iotdb-core/ainode/ainode/core/model/model_loader/model_config.py new file mode 100644 index 000000000000..a5532db4be8f --- /dev/null +++ b/iotdb-core/ainode/ainode/core/model/model_loader/model_config.py @@ -0,0 +1,148 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +thuTL 模型配置管理,支持 HuggingFace 格式和多版本兼容 +""" + +import json +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Dict, List, Optional, Union + + +@dataclass +class ModelConfig: + """thuTL 统一模型配置基类""" + + # 基础字段 + model_type: str + hidden_size: int + num_hidden_layers: int + num_attention_heads: int + input_token_len: int + output_token_lens: List[int] + + # 通用字段 + hidden_act: str = "silu" + intermediate_size: int = 2048 + max_position_embeddings: int = 10000 + rope_theta: int = 10000 + torch_dtype: str = "float32" + initializer_range: float = 0.02 + use_cache: bool = True + + # 版本信息 + transformers_version: Optional[str] = None + thuTL_version: Optional[str] = "1.0.0" + + # 未知字段存储 + _extras: Dict[str, Any] = field(default_factory=dict, repr=False) + + @classmethod + def from_json(cls, config_path: Union[str, Path]) -> "ModelConfig": + """从 JSON 配置文件加载""" + config_data = json.loads(Path(config_path).read_text(encoding="utf-8")) + + # 应用版本兼容性补丁 + config_data = cls._apply_version_patches(config_data) + + # 根据模型类型选择具体配置类 + model_type = config_data.get("model_type", "unknown") + config_cls = cls._get_config_class(model_type) + + # 分离已知和未知字段 + known_fields = set() + for cls_in_mro in config_cls.__mro__: + if hasattr(cls_in_mro, "__dataclass_fields__"): + known_fields.update(cls_in_mro.__dataclass_fields__.keys()) + + known = {} + extras = {} + for k, v in config_data.items(): + if k in known_fields and not k.startswith("_"): + known[k] = v + else: + extras[k] = v + + known["_extras"] = extras + return config_cls(**known) + + @classmethod + def _get_config_class(cls, model_type: str) -> type: + """根据模型类型选择配置类""" + if model_type == "timer": + from ..timerxl.configuration_timer import TimerConfig + + return TimerConfig + elif model_type == "sundial": + from ..sundial.configuration_sundial import SundialConfig + + return SundialConfig + else: + return cls # 使用基类作为后备 + + @classmethod + def _apply_version_patches(cls, config_data: Dict[str, Any]) -> Dict[str, Any]: + """应用版本兼容性补丁""" + version = config_data.get("transformers_version", "unknown") + + # 补丁示例:不同版本间的字段映射 + patches = { + # 历史版本兼容 + "legacy": { + "n_embd": "hidden_size", + "n_layer": "num_hidden_layers", + "n_head": "num_attention_heads", + "seq_len": "input_token_len", + }, + # 特定版本补丁 + "4.30.0": {"attention_dropout": "attn_dropout_rate"}, + } + + # 应用通用legacy补丁 + if "legacy" in patches: + for old_key, new_key in patches["legacy"].items(): + if old_key in config_data and new_key not in config_data: + config_data[new_key] = config_data.pop(old_key) + + # 应用版本特定补丁 + if version in patches: + for old_key, new_key in patches[version].items(): + if old_key in config_data and new_key not in config_data: + config_data[new_key] = config_data.pop(old_key) + + return config_data + + def to_dict(self) -> Dict[str, Any]: + """转换为字典格式""" + result = {} + for field_name, field_def in self.__dataclass_fields__.items(): + if not field_name.startswith("_"): + result[field_name] = getattr(self, field_name) + + # 添加扩展字段 + result.update(self._extras) + return result + + def save_json(self, path: Union[str, Path]): + """保存为 JSON 文件""" + Path(path).write_text( + json.dumps(self.to_dict(), indent=2, ensure_ascii=False), encoding="utf-8" + ) + + +# 具体模型配置类将在各自的模块中定义 diff --git a/iotdb-core/ainode/ainode/core/model/model_loader/weight_loader.py b/iotdb-core/ainode/ainode/core/model/model_loader/weight_loader.py new file mode 100644 index 000000000000..b74423fa902c --- /dev/null +++ b/iotdb-core/ainode/ainode/core/model/model_loader/weight_loader.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/iotdb-core/ainode/ainode/core/model/model_storage.py b/iotdb-core/ainode/ainode/core/model/model_storage.py index c0e2a21c80a8..574eab8e81ce 100644 --- a/iotdb-core/ainode/ainode/core/model/model_storage.py +++ b/iotdb-core/ainode/ainode/core/model/model_storage.py @@ -19,6 +19,7 @@ import os import shutil from collections.abc import Callable +from pathlib import Path import torch import torch._dynamo @@ -30,9 +31,15 @@ from ainode.core.log import Logger from ainode.core.model.model_factory import fetch_model_by_uri from ainode.core.util.lock import ModelLockPool +from ainode.core.model.config_parser import parse_config_file +from ainode.core.model.safetensor_loader import load_weights_as_state_dict logger = Logger() +# IoTDB 模型相关文件名 +IOTDB_CONFIG_FILES = ["config.json", "configuration.json"] +IOTDB_WEIGHT_FILES = ["model.safetensors", "pytorch_model.safetensors", "model.pt", "pytorch_model.pt"] + class ModelStorage(object): def __init__(self): @@ -68,39 +75,162 @@ def register_model(self, model_id: str, uri: str): config_storage_path = os.path.join(storage_path, DEFAULT_CONFIG_FILE_NAME) return fetch_model_by_uri(uri, model_storage_path, config_storage_path) + def _detect_model_format(self, model_dir: str) -> tuple: + """ + 检测模型格式:legacy (model.pt + config.yaml) 或 IoTDB (config.json + safetensors) + + Args: + model_dir: 模型目录路径 + + Returns: + (format_type, config_file, weight_file): 格式类型和对应的文件名 + """ + model_path = Path(model_dir) + + # 检查 IoTDB 格式 + for config_file in IOTDB_CONFIG_FILES: + config_path = model_path / config_file + if config_path.exists(): + # 查找权重文件 + for weight_file in IOTDB_WEIGHT_FILES: + weight_path = model_path / weight_file + if weight_path.exists(): + logger.debug(f"检测到 IoTDB 格式: {config_file} + {weight_file}") + return "iotdb", config_file, weight_file + + # 检查 legacy 格式 + legacy_config = model_path / DEFAULT_CONFIG_FILE_NAME + legacy_model = model_path / DEFAULT_MODEL_FILE_NAME + if legacy_config.exists() and legacy_model.exists(): + logger.debug(f"检测到 legacy 格式: {DEFAULT_CONFIG_FILE_NAME} + {DEFAULT_MODEL_FILE_NAME}") + return "legacy", DEFAULT_CONFIG_FILE_NAME, DEFAULT_MODEL_FILE_NAME + + return None, None, None + + def _load_iotdb_model(self, model_dir: str, config_file: str, weight_file: str, acceleration: bool) -> Callable: + """ + 加载 IoTDB 格式的模型 + + Args: + model_dir: 模型目录 + config_file: 配置文件名 + weight_file: 权重文件名 + acceleration: 是否启用加速 + + Returns: + 加载的模型 + """ + config_path = os.path.join(model_dir, config_file) + weight_path = os.path.join(model_dir, weight_file) + + # 检查缓存 + cache_key = f"{config_path}:{weight_path}" + if cache_key in self._model_cache: + model = self._model_cache[cache_key] + if (isinstance(model, torch._dynamo.eval_frame.OptimizedModule) or not acceleration): + return model + else: + model = torch.compile(model) + self._model_cache[cache_key] = model + return model + + try: + # 解析配置文件 + config_dict = parse_config_file(config_path) + model_type = config_dict.get("model_type", "unknown") + + # 根据模型类型动态导入 + if model_type == "timer": + from ainode.model.timerxl import TimerForPrediction as ModelClass + elif model_type == "sundial": + from ainode.model.sundial import SundialForPrediction as ModelClass + else: + raise ValueError(f"不支持的模型类型: {model_type}") + + # 加载模型 + model = ModelClass.from_pretrained(config_path, weight_path) + model.eval() + + # 转换为 TorchScript 以便缓存和部署 + try: + # 创建示例输入 + input_length = config_dict.get("input_token_len", 96) + example_input = torch.randn(1, input_length) + + # 转换为 TorchScript + model = torch.jit.trace(model, example_input) + logger.debug(f"成功转换 IoTDB 模型为 TorchScript: {model_type}") + except Exception as e: + logger.warning(f"TorchScript 转换失败,使用原生模型: {e}") + + # 应用加速 + if acceleration: + try: + model = torch.compile(model) + logger.debug(f"启用模型加速: {model_type}") + except Exception as e: + logger.warning(f"模型加速失败,使用普通模式: {e}") + + # 缓存模型 + self._model_cache[cache_key] = model + return model + + except Exception as e: + logger.error(f"加载 IoTDB 模型失败: {e}") + raise ModelNotExistError(f"无法加载模型: {weight_path}") + + def _load_legacy_model(self, model_path: str, acceleration: bool) -> Callable: + """ + 加载 legacy 格式的模型 + + Args: + model_path: 模型文件路径 + acceleration: 是否启用加速 + + Returns: + 加载的模型 + """ + if model_path in self._model_cache: + model = self._model_cache[model_path] + if (isinstance(model, torch._dynamo.eval_frame.OptimizedModule) or not acceleration): + return model + else: + model = torch.compile(model) + self._model_cache[model_path] = model + return model + else: + if not os.path.exists(model_path): + raise ModelNotExistError(model_path) + else: + model = torch.jit.load(model_path) + if acceleration: + try: + model = torch.compile(model) + except Exception as e: + logger.warning(f"acceleration failed, fallback to normal mode: {str(e)}") + self._model_cache[model_path] = model + return model + def load_model(self, model_id: str, acceleration: bool) -> Callable: """ Returns: model: a ScriptModule contains model architecture and parameters, which can be deployed cross-platform """ ain_models_dir = os.path.join(self._model_dir, f"{model_id}") - model_path = os.path.join(ain_models_dir, DEFAULT_MODEL_FILE_NAME) + with self._lock_pool.get_lock(model_id).read_lock(): - if model_path in self._model_cache: - model = self._model_cache[model_path] - if ( - isinstance(model, torch._dynamo.eval_frame.OptimizedModule) - or not acceleration - ): - return model - else: - model = torch.compile(model) - self._model_cache[model_path] = model - return model + # 检测模型格式 + format_type, config_file, weight_file = self._detect_model_format(ain_models_dir) + + if format_type == "iotdb": + logger.info(f"加载 IoTDB 格式模型: {model_id}") + return self._load_iotdb_model(ain_models_dir, config_file, weight_file, acceleration) + elif format_type == "legacy": + logger.info(f"加载 legacy 格式模型: {model_id}") + legacy_model_path = os.path.join(ain_models_dir, DEFAULT_MODEL_FILE_NAME) + return self._load_legacy_model(legacy_model_path, acceleration) else: - if not os.path.exists(model_path): - raise ModelNotExistError(model_path) - else: - model = torch.jit.load(model_path) - if acceleration: - try: - model = torch.compile(model) - except Exception as e: - logger.warning( - f"acceleration failed, fallback to normal mode: {str(e)}" - ) - self._model_cache[model_path] = model - return model + raise ModelNotExistError(f"未找到有效的模型文件: {ain_models_dir}") def delete_model(self, model_id: str) -> None: """ @@ -112,10 +242,23 @@ def delete_model(self, model_id: str) -> None: storage_path = os.path.join(self._model_dir, f"{model_id}") with self._lock_pool.get_lock(model_id).write_lock(): if os.path.exists(storage_path): - for file_name in os.listdir(storage_path): - self._remove_from_cache(os.path.join(storage_path, file_name)) + # 清理缓存中的所有相关条目 + keys_to_remove = [] + for cache_key in self._model_cache.keys(): + if storage_path in cache_key: + keys_to_remove.append(cache_key) + + for key in keys_to_remove: + del self._model_cache[key] + + # 删除文件 shutil.rmtree(storage_path) + logger.info(f"成功删除模型: {model_id}") def _remove_from_cache(self, file_path: str) -> None: if file_path in self._model_cache: del self._model_cache[file_path] + + def _get_model_directory(self, model_id: str) -> str: + """获取模型目录路径""" + return os.path.join(self._model_dir, model_id) \ No newline at end of file diff --git a/iotdb-core/ainode/ainode/core/model/safetensor_loader.py b/iotdb-core/ainode/ainode/core/model/safetensor_loader.py new file mode 100644 index 000000000000..7087d7200aed --- /dev/null +++ b/iotdb-core/ainode/ainode/core/model/safetensor_loader.py @@ -0,0 +1,176 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from pathlib import Path +from typing import Dict, Iterator, Tuple, Union + +import torch + +from ainode.core.log import Logger + +logger = Logger() + + +def load_weights_as_state_dict( + weights_path: Union[str, Path] +) -> Dict[str, torch.Tensor]: + """ + 加载权重文件并返回state_dict格式 + 优先级:.safetensors > .pt > .pth > .bin + + Args: + weights_path: 权重文件路径 + + Returns: + 权重字典 {key: tensor} + """ + weights_path = Path(weights_path) + + if weights_path.is_file(): + return _load_single_weight_file(weights_path) + elif weights_path.is_dir(): + return _load_weights_from_directory(weights_path) + else: + raise FileNotFoundError(f"权重路径不存在: {weights_path}") + + +def iter_weights(weights_path: Union[str, Path]) -> Iterator[Tuple[str, torch.Tensor]]: + """ + 迭代器方式加载权重,节省内存 + + Args: + weights_path: 权重路径 + + Yields: + (key, tensor) 元组 + """ + weights = load_weights_as_state_dict(weights_path) + for key, tensor in weights.items(): + yield key, tensor + + +def _load_single_weight_file(file_path: Path) -> Dict[str, torch.Tensor]: + """加载单个权重文件""" + suffix = file_path.suffix.lower() + + logger.debug(f"加载权重文件: {file_path}") + + if suffix == ".safetensors": + return _load_safetensors_file(file_path) + elif suffix in [".pt", ".pth", ".bin"]: + return _load_pytorch_file(file_path) + else: + raise ValueError(f"不支持的权重文件格式: {suffix}") + + +def _load_weights_from_directory(dir_path: Path) -> Dict[str, torch.Tensor]: + """从目录加载权重,按优先级查找""" + priority_patterns = [ + "model.safetensors", + "pytorch_model.safetensors", + "model.pt", + "pytorch_model.pt", + "model.pth", + "pytorch_model.pth", + "pytorch_model.bin", + ] + + # 查找单一文件 + for pattern in priority_patterns: + file_path = dir_path / pattern + if file_path.exists(): + logger.debug(f"找到权重文件: {file_path}") + return _load_single_weight_file(file_path) + + # 查找分片文件 + safetensor_files = list(dir_path.glob("*.safetensors")) + pytorch_files = list(dir_path.glob("*.pt")) + list(dir_path.glob("*.pth")) + + if safetensor_files: + return _load_sharded_safetensors(safetensor_files) + elif pytorch_files: + return _load_sharded_pytorch(pytorch_files) + + raise FileNotFoundError(f"在目录 {dir_path} 中找不到权重文件") + + +def _load_safetensors_file(file_path: Path) -> Dict[str, torch.Tensor]: + """加载SafeTensors格式文件""" + try: + from safetensors import safe_open + except ImportError: + logger.warning("safetensors未安装,尝试使用PyTorch格式") + # 尝试查找对应的.pt文件 + pt_path = file_path.with_suffix(".pt") + if pt_path.exists(): + return _load_pytorch_file(pt_path) + raise ImportError("需要安装safetensors: pip install safetensors") + + weights = {} + with safe_open(file_path, framework="pt", device="cpu") as f: + for key in f.keys(): + weights[key] = f.get_tensor(key) + + logger.debug(f"SafeTensors加载完成: {len(weights)} 个参数") + return weights + + +def _load_pytorch_file(file_path: Path) -> Dict[str, torch.Tensor]: + """加载PyTorch格式文件""" + weights = torch.load(file_path, map_location="cpu") + + # 处理不同的权重结构 + if isinstance(weights, dict): + if "state_dict" in weights: + weights = weights["state_dict"] + elif "model" in weights: + weights = weights["model"] + elif "model_state_dict" in weights: + weights = weights["model_state_dict"] + + logger.debug(f"PyTorch权重加载完成: {len(weights)} 个参数") + return weights + + +def _load_sharded_safetensors(file_list: list) -> Dict[str, torch.Tensor]: + """加载分片的SafeTensors文件""" + from safetensors import safe_open + + weights = {} + for file_path in sorted(file_list): + with safe_open(file_path, framework="pt", device="cpu") as f: + for key in f.keys(): + weights[key] = f.get_tensor(key) + + logger.debug( + f"分片SafeTensors加载完成: {len(file_list)} 个文件, {len(weights)} 个参数" + ) + return weights + + +def _load_sharded_pytorch(file_list: list) -> Dict[str, torch.Tensor]: + """加载分片的PyTorch文件""" + weights = {} + for file_path in sorted(file_list): + shard_weights = torch.load(file_path, map_location="cpu") + if isinstance(shard_weights, dict): + weights.update(shard_weights) + + logger.debug( + f"分片PyTorch权重加载完成: {len(file_list)} 个文件, {len(weights)} 个参数" + ) + return weights diff --git a/iotdb-core/ainode/ainode/core/util/model_utils.py b/iotdb-core/ainode/ainode/core/util/model_utils.py new file mode 100644 index 000000000000..2c9a0d836929 --- /dev/null +++ b/iotdb-core/ainode/ainode/core/util/model_utils.py @@ -0,0 +1,228 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import json +import os +from pathlib import Path +from typing import Dict, Any, Tuple, Optional +from urllib.parse import urlparse + +from ainode.core.constant import DEFAULT_CONFIG_FILE_NAME, DEFAULT_MODEL_FILE_NAME +from ainode.core.exception import InvalidUriError, ModelFormatError +from ainode.core.log import Logger + +logger = Logger() + +# IoTDB模型文件名常量 +IOTDB_CONFIG_FILES = ["config.json", "configuration.json"] +IOTDB_WEIGHT_FILES = ["model.safetensors", "pytorch_model.safetensors", "model.pt", "pytorch_model.pt"] + +def detect_model_format(model_path: str) -> Tuple[Optional[str], Optional[str], Optional[str]]: + """ + 检测模型格式:legacy (model.pt + config.yaml) 或 IoTDB (config.json + safetensors) + + Args: + model_path: 模型目录路径 + + Returns: + (format_type, config_file, weight_file): 格式类型和对应的文件名 + """ + model_dir = Path(model_path) + + if not model_dir.exists(): + logger.error(f"模型路径不存在: {model_path}") + return None, None, None + + # 检查 IoTDB 格式 + for config_file in IOTDB_CONFIG_FILES: + config_path = model_dir / config_file + if config_path.exists(): + # 查找权重文件 + for weight_file in IOTDB_WEIGHT_FILES: + weight_path = model_dir / weight_file + if weight_path.exists(): + logger.info(f"检测到 IoTDB 格式: {config_file} + {weight_file}") + return "iotdb", config_file, weight_file + + # 检查 legacy 格式 + legacy_config = model_dir / DEFAULT_CONFIG_FILE_NAME + legacy_model = model_dir / DEFAULT_MODEL_FILE_NAME + if legacy_config.exists() and legacy_model.exists(): + logger.info(f"检测到 legacy 格式: {DEFAULT_CONFIG_FILE_NAME} + {DEFAULT_MODEL_FILE_NAME}") + return "legacy", DEFAULT_CONFIG_FILE_NAME, DEFAULT_MODEL_FILE_NAME + + logger.warning(f"未能识别模型格式: {model_path}") + return None, None, None + +def validate_iotdb_model_config(config_dict: Dict[str, Any]) -> bool: + """ + 验证IoTDB模型配置的有效性 + + Args: + config_dict: 配置字典 + + Returns: + 是否有效 + """ + required_fields = ["model_type", "input_token_len", "output_token_lens"] + + for field in required_fields: + if field not in config_dict: + logger.error(f"缺少必需字段: {field}") + return False + + model_type = config_dict.get("model_type") + supported_types = ["timer", "sundial"] + if model_type not in supported_types: + logger.error(f"不支持的模型类型: {model_type}, 支持的类型: {supported_types}") + return False + + # 验证输入输出长度 + input_len = config_dict.get("input_token_len") + output_lens = config_dict.get("output_token_lens") + + if not isinstance(input_len, int) or input_len <= 0: + logger.error(f"无效的输入长度: {input_len}") + return False + + if not isinstance(output_lens, list) or len(output_lens) == 0: + logger.error(f"无效的输出长度配置: {output_lens}") + return False + + logger.info(f"IoTDB模型配置验证通过: {model_type}") + return True + +def parse_model_uri(uri: str) -> Tuple[bool, str]: + """ + 解析模型URI,判断是网络路径还是本地路径 + + Args: + uri: 模型URI + + Returns: + (is_network, parsed_uri): 是否为网络路径和解析后的URI + """ + try: + parsed = urlparse(uri) + is_network = parsed.scheme in ("http", "https") + + if is_network: + logger.info(f"检测到网络URI: {uri}") + return True, uri + else: + # 处理本地路径 + if parsed.scheme == "file": + uri = uri[7:] # 移除 file:// + + # 处理 ~ 符号 + uri = os.path.expanduser(uri) + logger.info(f"检测到本地URI: {uri}") + return False, uri + + except Exception as e: + logger.error(f"URI解析失败: {uri}, 错误: {e}") + raise InvalidUriError(uri) + +def get_model_info_from_config(config_path: str) -> Dict[str, Any]: + """ + 从配置文件中提取模型信息 + + Args: + config_path: 配置文件路径 + + Returns: + 模型信息字典 + """ + try: + config_file = Path(config_path) + + if config_file.suffix.lower() == '.json': + # IoTDB格式配置 + with open(config_path, 'r', encoding='utf-8') as f: + config = json.load(f) + + model_info = { + "format": "iotdb", + "model_type": config.get("model_type", "unknown"), + "input_length": config.get("input_token_len", 0), + "output_length": config.get("output_token_lens", [0])[0] if config.get("output_token_lens") else 0, + "architecture": config.get("architecture", ""), + "task": config.get("task", "forecasting") + } + + else: + # Legacy格式配置 + import yaml + with open(config_path, 'r', encoding='utf-8') as f: + config = yaml.safe_load(f) + + configs = config.get("configs", {}) + model_info = { + "format": "legacy", + "model_type": "legacy", + "input_length": configs.get("input_shape", [0, 0])[0], + "output_length": configs.get("output_shape", [0, 0])[0], + "architecture": "legacy", + "task": "forecasting" + } + + logger.info(f"提取模型信息: {model_info}") + return model_info + + except Exception as e: + logger.error(f"提取模型信息失败: {config_path}, 错误: {e}") + raise ModelFormatError(config_path, "valid config file") + +def validate_model_name(model_name: str) -> bool: + """ + 验证模型名称的有效性 + + Args: + model_name: 模型名称 + + Returns: + 是否有效 + """ + if not model_name: + return False + + # 允许字母、数字、下划线、连字符 + import re + pattern = r'^[a-zA-Z0-9_-]+$' + + if re.match(pattern, model_name): + logger.debug(f"模型名称验证通过: {model_name}") + return True + else: + logger.error(f"无效的模型名称: {model_name}") + return False + +def create_model_status_message(status: str, details: str = "") -> str: + """ + 创建模型状态消息 + + Args: + status: 状态 + details: 详细信息 + + Returns: + 状态消息 + """ + if details: + return f"模型状态: {status} - {details}" + else: + return f"模型状态: {status}" \ No newline at end of file diff --git a/iotdb-core/ainode/ainode/core/util/status.py b/iotdb-core/ainode/ainode/core/util/status.py index 37368b0068b1..e30522d79c85 100644 --- a/iotdb-core/ainode/ainode/core/util/status.py +++ b/iotdb-core/ainode/ainode/core/util/status.py @@ -19,6 +19,7 @@ from ainode.core.constant import TSStatusCode from ainode.core.log import Logger from ainode.thrift.common.ttypes import TSStatus +from ainode.core.constant import TSStatusCode def get_status(status_code: TSStatusCode, message: str = None) -> TSStatus: @@ -31,3 +32,24 @@ def verify_success(status: TSStatus, err_msg: str) -> None: if status.code != TSStatusCode.SUCCESS_STATUS.get_status_code(): Logger().warning(err_msg + ", error status is ", status) raise RuntimeError(str(status.code) + ": " + status.message) + +# TODO: create model增强:获取模型状态并检查操作 +def get_model_status(status_code: TSStatusCode, model_id: str = "", message: str = None) -> TSStatus: + """ + 获取模型相关的状态对象 + """ + status = TSStatus(status_code.get_status_code()) + if message: + status.message = f"Model {model_id}: {message}" if model_id else message + else: + status.message = f"Model {model_id} operation completed" if model_id else "Operation completed" + return status + +def verify_model_success(status: TSStatus, model_id: str, operation: str) -> None: + """ + 验证模型操作是否成功 + """ + if status.code != TSStatusCode.SUCCESS_STATUS.get_status_code(): + error_msg = f"Model {model_id} {operation} failed" + Logger().error(error_msg + f", status: {status}") + raise RuntimeError(f"{status.code}: {status.message}") diff --git a/iotdb-core/antlr/src/main/antlr4/org/apache/iotdb/db/qp/sql/IoTDBSqlParser.g4 b/iotdb-core/antlr/src/main/antlr4/org/apache/iotdb/db/qp/sql/IoTDBSqlParser.g4 index 6a0c97ec3216..8023a83aeeef 100644 --- a/iotdb-core/antlr/src/main/antlr4/org/apache/iotdb/db/qp/sql/IoTDBSqlParser.g4 +++ b/iotdb-core/antlr/src/main/antlr4/org/apache/iotdb/db/qp/sql/IoTDBSqlParser.g4 @@ -702,6 +702,10 @@ createModel | CREATE MODEL modelType=identifier modelId=identifier (WITH HYPERPARAMETERS LR_BRACKET hparamPair (COMMA hparamPair)* RR_BRACKET)? (FROM MODEL existingModelId=identifier)? ON DATASET LR_BRACKET trainingData RR_BRACKET ; +uriClause + : USING URI uriValue=STRING_LITERAL + ; + trainingData : dataElement(COMMA dataElement)* ; diff --git a/iotdb-core/antlr/src/main/antlr4/org/apache/iotdb/db/qp/sql/SqlLexer.g4 b/iotdb-core/antlr/src/main/antlr4/org/apache/iotdb/db/qp/sql/SqlLexer.g4 index 22cb1e0f539c..34de1978f179 100644 --- a/iotdb-core/antlr/src/main/antlr4/org/apache/iotdb/db/qp/sql/SqlLexer.g4 +++ b/iotdb-core/antlr/src/main/antlr4/org/apache/iotdb/db/qp/sql/SqlLexer.g4 @@ -510,6 +510,14 @@ MODELS : M O D E L S ; +USING + : U S I N G + ; + +URI + : U R I + ; + MODIFY : M O D I F Y ;