|
| 1 | +package modelengine.fit.jade.aipp.model.repository.impl; |
| 2 | + |
| 3 | +import modelengine.fit.jade.aipp.model.mapper.ModelMapper; |
| 4 | +import modelengine.fit.jade.aipp.model.mapper.UserModelMapper; |
| 5 | +import modelengine.fit.jade.aipp.model.po.ModelPo; |
| 6 | +import modelengine.fit.jade.aipp.model.po.UserModelDetailPo; |
| 7 | +import modelengine.fit.jade.aipp.model.po.UserModelPo; |
| 8 | +import modelengine.fit.jade.aipp.model.repository.UserModelPluginRepo; |
| 9 | +import modelengine.fit.jade.aipp.model.repository.UserModelRepo; |
| 10 | +import modelengine.fitframework.annotation.Component; |
| 11 | +import modelengine.fitframework.annotation.Fitable; |
| 12 | +import modelengine.fitframework.annotation.Property; |
| 13 | +import modelengine.fitframework.log.Logger; |
| 14 | +import modelengine.fitframework.util.CollectionUtils; |
| 15 | +import modelengine.jade.carver.tool.annotation.Attribute; |
| 16 | +import modelengine.jade.carver.tool.annotation.Group; |
| 17 | +import modelengine.jade.carver.tool.annotation.ToolMethod; |
| 18 | + |
| 19 | +import java.util.*; |
| 20 | +import java.util.stream.Collectors; |
| 21 | + |
| 22 | +/** |
| 23 | + * 表示用户模型信息用于插件的持久化层的接口 {@link UserModelRepo} 的实现。 |
| 24 | + * |
| 25 | + * @author lizhichao |
| 26 | + * @since 2025/4/9 |
| 27 | + */ |
| 28 | +@Component |
| 29 | +@Group(name = "User_Model_Tool_Impl") |
| 30 | +public class UserModelPluginRepoImpl implements UserModelPluginRepo { |
| 31 | + private static final Logger log = Logger.get(UserModelRepoImpl.class); |
| 32 | + private static final String FITABLE_ID = "aipp.model.repository"; |
| 33 | + public static final String DEFAULT_MODEL_TYPE = "chat_completions"; |
| 34 | + private final ModelMapper modelMapper; |
| 35 | + private final UserModelMapper userModelMapper; |
| 36 | + |
| 37 | + /** |
| 38 | + * 构造方法。 |
| 39 | + * |
| 40 | + * @param modelMapper 模型信息表的 MyBatis 映射接口,用于处理模型增删查改。 |
| 41 | + * @param userModelMapper 用户与模型绑定关系的 MyBatis 映射接口,用于管理用户模型映射数据。 |
| 42 | + */ |
| 43 | + public UserModelPluginRepoImpl(ModelMapper modelMapper, UserModelMapper userModelMapper) { |
| 44 | + this.modelMapper = modelMapper; |
| 45 | + this.userModelMapper = userModelMapper; |
| 46 | + } |
| 47 | + |
| 48 | + @Override |
| 49 | + @Fitable(id = FITABLE_ID) |
| 50 | + @ToolMethod(name = "获取用户模型列表", description = "根据用户标识来查询该用户可用的模型列表", extensions = { |
| 51 | + @Attribute(key = "tags", value = "FIT"), @Attribute(key = "tags", value = "MODEL") |
| 52 | + }) |
| 53 | + @Property(description = "返回该用户可用的模型列表") |
| 54 | + public List<UserModelDetailPo> getUserModelList(String userId) { |
| 55 | + log.info("start get model list for {}.", userId); |
| 56 | + List<UserModelPo> userModelPos = this.userModelMapper.listUserModels(userId); |
| 57 | + if (CollectionUtils.isEmpty(userModelPos)) { |
| 58 | + log.warn("No user model records found for userId={}.", userId); |
| 59 | + return Collections.emptyList(); |
| 60 | + } |
| 61 | + List<String> modelIds = userModelPos.stream() |
| 62 | + .map(UserModelPo::getModelId) |
| 63 | + .distinct() |
| 64 | + .collect(Collectors.toList()); |
| 65 | + List<ModelPo> modelPos = this.modelMapper.listModels(modelIds); |
| 66 | + // 构建 modelId → ModelPo 映射 |
| 67 | + Map<String, ModelPo> modelMap = modelPos.stream() |
| 68 | + .map(model -> Map.entry(model.getModelId(), model)) |
| 69 | + .collect(Collectors.toMap( |
| 70 | + Map.Entry::getKey, |
| 71 | + Map.Entry::getValue, |
| 72 | + (a, b) -> a |
| 73 | + )); |
| 74 | + return userModelPos.stream().map(userModel -> { |
| 75 | + ModelPo model = modelMap.get(userModel.getModelId()); |
| 76 | + return new UserModelDetailPo( |
| 77 | + userModel.getCreatedAt(), |
| 78 | + userModel.getModelId(), |
| 79 | + userModel.getUserId(), |
| 80 | + model != null ? model.getName() : null, |
| 81 | + model != null ? model.getBaseUrl() : null, |
| 82 | + userModel.getIsDefault() |
| 83 | + ); |
| 84 | + }).collect(Collectors.toList()); |
| 85 | + } |
| 86 | + |
| 87 | + @Override |
| 88 | + @Fitable(id = FITABLE_ID) |
| 89 | + @ToolMethod(name = "添加模型", description = "为用户添加可用的模型信息", extensions = { |
| 90 | + @Attribute(key = "tags", value = "FIT"), @Attribute(key = "tags", value = "MODEL") |
| 91 | + }) |
| 92 | + @Property(description = "为用户添加可用的模型信息") |
| 93 | + public String addUserModel(String userId, String apiKey, |
| 94 | + String modelName, String baseUrl) { |
| 95 | + log.info("start add user model for {}.", userId); |
| 96 | + String modelId = UUID.randomUUID().toString().replace("-", ""); |
| 97 | + int isDefault = this.userModelMapper.userHasDefaultModel(userId) ? 0 : 1; |
| 98 | + |
| 99 | + ModelPo modelPo = new ModelPo(modelId, modelName, modelId, baseUrl, DEFAULT_MODEL_TYPE); |
| 100 | + modelPo.setCreatedBy(userId); |
| 101 | + modelPo.setUpdatedBy(userId); |
| 102 | + this.modelMapper.insertModel(modelPo); |
| 103 | + |
| 104 | + UserModelPo userModelPo = new UserModelPo(userId, modelId, apiKey, isDefault); |
| 105 | + userModelPo.setCreatedBy(userId); |
| 106 | + userModelPo.setUpdatedBy(userId); |
| 107 | + this.userModelMapper.addUserModel(userModelPo); |
| 108 | + return "添加模型成功。"; |
| 109 | + } |
| 110 | + |
| 111 | + @Override |
| 112 | + @Fitable(id = FITABLE_ID) |
| 113 | + @ToolMethod(name = "删除模型", description = "删除用户绑定的模型信息", extensions = { |
| 114 | + @Attribute(key = "tags", value = "FIT"), @Attribute(key = "tags", value = "MODEL") |
| 115 | + }) |
| 116 | + @Property(description = "删除用户绑定的模型信息") |
| 117 | + public String deleteUserModel(String userId, String modelId) { |
| 118 | + log.info("start delete user model for {}.", userId); |
| 119 | + List<UserModelPo> userModels = this.userModelMapper.listUserModels(userId); |
| 120 | + if (userModels == null || userModels.isEmpty()) { |
| 121 | + return "删除模型失败,当前用户没有任何模型记录。"; |
| 122 | + } |
| 123 | + |
| 124 | + UserModelPo target = userModels.stream() |
| 125 | + .filter(m -> Objects.equals(m.getModelId(), modelId)) |
| 126 | + .findFirst() |
| 127 | + .orElse(null); |
| 128 | + if (target == null) { |
| 129 | + return "删除模型失败,该模型不属于当前用户。"; |
| 130 | + } |
| 131 | + this.userModelMapper.deleteByModelId(modelId); |
| 132 | + this.modelMapper.deleteByModelId(modelId); |
| 133 | + // 如果删除的不是默认模型,直接返回 |
| 134 | + if (target.getIsDefault() != 1) { |
| 135 | + return "删除模型成功。"; |
| 136 | + } |
| 137 | + userModels.remove(target); |
| 138 | + // 如果没有默认模型,但还有其他记录,则设置最新创建的为默认 |
| 139 | + if (!userModels.isEmpty()) { |
| 140 | + UserModelPo latest = userModels.stream() |
| 141 | + .max(Comparator.comparing(UserModelPo::getCreatedAt)) |
| 142 | + .orElse(null); |
| 143 | + this.userModelMapper.switchDefaultForUser(userId, latest.getModelId()); |
| 144 | + return String.format("删除默认模型成功,添加%s为默认模型。", this.modelMapper.get(latest.getModelId()).getName()); |
| 145 | + } |
| 146 | + return "删除模型成功,当前无默认模型。"; |
| 147 | + } |
| 148 | + |
| 149 | + @Override |
| 150 | + @Fitable(id = FITABLE_ID) |
| 151 | + @ToolMethod(name = "切换默认模型", description = "将指定模型设置为用户的默认模型", extensions = { |
| 152 | + @Attribute(key = "tags", value = "FIT"), @Attribute(key = "tags", value = "MODEL") |
| 153 | + }) |
| 154 | + @Property(description = "将指定模型设置为用户的默认模型") |
| 155 | + public String switchDefaultModel(String userId, String modelId) { |
| 156 | + log.info("start switch default model for {}.", userId); |
| 157 | + int rows = this.userModelMapper.switchDefaultForUser(userId, modelId); |
| 158 | + if (rows == 0) { |
| 159 | + return "未查到对应模型。"; |
| 160 | + } |
| 161 | + return String.format("已切换%s为默认模型。", this.modelMapper.get(modelId).getName()); |
| 162 | + } |
| 163 | +} |
0 commit comments