@@ -12,6 +12,10 @@ Models::Models() : db_(cortex::db::Database::GetInstance().db()) {
1212 db_.exec (
1313 " CREATE TABLE IF NOT EXISTS models ("
1414 " model_id TEXT PRIMARY KEY,"
15+ " model_format TEXT,"
16+ " model_source TEXT,"
17+ " status TEXT,"
18+ " engine TEXT,"
1519 " author_repo_id TEXT,"
1620 " branch_name TEXT,"
1721 " path_to_model_yaml TEXT,"
@@ -22,14 +26,40 @@ Models::Models(SQLite::Database& db) : db_(db) {
2226 db_.exec (
2327 " CREATE TABLE IF NOT EXISTS models ("
2428 " model_id TEXT PRIMARY KEY,"
29+ " model_format TEXT,"
30+ " model_source TEXT,"
31+ " status TEXT,"
32+ " engine TEXT,"
2533 " author_repo_id TEXT,"
2634 " branch_name TEXT,"
2735 " path_to_model_yaml TEXT,"
28- " model_alias TEXT UNIQUE );" );
36+ " model_alias TEXT);" );
2937}
30-
3138Models::~Models () {}
3239
40+ std::string Models::StatusToString (ModelStatus status) const {
41+ switch (status) {
42+ case ModelStatus::Remote:
43+ return " remote" ;
44+ case ModelStatus::Downloaded:
45+ return " downloaded" ;
46+ case ModelStatus::Undownloaded:
47+ return " undownloaded" ;
48+ }
49+ return " unknown" ;
50+ }
51+
52+ ModelStatus Models::StringToStatus (const std::string& status_str) const {
53+ if (status_str == " remote" ) {
54+ return ModelStatus::Remote;
55+ } else if (status_str == " downloaded" ) {
56+ return ModelStatus::Downloaded;
57+ } else if (status_str == " undownloaded" ) {
58+ return ModelStatus::Undownloaded;
59+ }
60+ throw std::invalid_argument (" Invalid status string" );
61+ }
62+
3363cpp::result<std::vector<ModelEntry>, std::string> Models::LoadModelList ()
3464 const {
3565 try {
@@ -57,16 +87,21 @@ cpp::result<std::vector<ModelEntry>, std::string> Models::LoadModelListNoLock()
5787 try {
5888 std::vector<ModelEntry> entries;
5989 SQLite::Statement query (db_,
60- " SELECT model_id, author_repo_id, branch_name, "
90+ " SELECT model_id, model_format, model_source, "
91+ " status, engine, author_repo_id, branch_name, "
6192 " path_to_model_yaml, model_alias FROM models" );
6293
6394 while (query.executeStep ()) {
6495 ModelEntry entry;
6596 entry.model = query.getColumn (0 ).getString ();
66- entry.author_repo_id = query.getColumn (1 ).getString ();
67- entry.branch_name = query.getColumn (2 ).getString ();
68- entry.path_to_model_yaml = query.getColumn (3 ).getString ();
69- entry.model_alias = query.getColumn (4 ).getString ();
97+ entry.model_format = query.getColumn (1 ).getString ();
98+ entry.model_source = query.getColumn (2 ).getString ();
99+ entry.status = StringToStatus (query.getColumn (3 ).getString ());
100+ entry.engine = query.getColumn (4 ).getString ();
101+ entry.author_repo_id = query.getColumn (5 ).getString ();
102+ entry.branch_name = query.getColumn (6 ).getString ();
103+ entry.path_to_model_yaml = query.getColumn (7 ).getString ();
104+ entry.model_alias = query.getColumn (8 ).getString ();
70105 entries.push_back (entry);
71106 }
72107 return entries;
@@ -140,7 +175,8 @@ cpp::result<ModelEntry, std::string> Models::GetModelInfo(
140175 const std::string& identifier) const {
141176 try {
142177 SQLite::Statement query (db_,
143- " SELECT model_id, author_repo_id, branch_name, "
178+ " SELECT model_id, model_format, model_source, "
179+ " status, engine, author_repo_id, branch_name, "
144180 " path_to_model_yaml, model_alias FROM models "
145181 " WHERE model_id = ? OR model_alias = ?" );
146182
@@ -149,10 +185,14 @@ cpp::result<ModelEntry, std::string> Models::GetModelInfo(
149185 if (query.executeStep ()) {
150186 ModelEntry entry;
151187 entry.model = query.getColumn (0 ).getString ();
152- entry.author_repo_id = query.getColumn (1 ).getString ();
153- entry.branch_name = query.getColumn (2 ).getString ();
154- entry.path_to_model_yaml = query.getColumn (3 ).getString ();
155- entry.model_alias = query.getColumn (4 ).getString ();
188+ entry.model_format = query.getColumn (1 ).getString ();
189+ entry.model_source = query.getColumn (2 ).getString ();
190+ entry.status = StringToStatus (query.getColumn (3 ).getString ());
191+ entry.engine = query.getColumn (4 ).getString ();
192+ entry.author_repo_id = query.getColumn (5 ).getString ();
193+ entry.branch_name = query.getColumn (6 ).getString ();
194+ entry.path_to_model_yaml = query.getColumn (7 ).getString ();
195+ entry.model_alias = query.getColumn (8 ).getString ();
156196 return entry;
157197 } else {
158198 return cpp::fail (" Model not found: " + identifier);
@@ -164,6 +204,10 @@ cpp::result<ModelEntry, std::string> Models::GetModelInfo(
164204
165205void Models::PrintModelInfo (const ModelEntry& entry) const {
166206 LOG_INFO << " Model ID: " << entry.model ;
207+ LOG_INFO << " Model Format: " << entry.model_format ;
208+ LOG_INFO << " Model Source: " << entry.model_source ;
209+ LOG_INFO << " Status: " << StatusToString (entry.status );
210+ LOG_INFO << " Engine: " << entry.engine ;
167211 LOG_INFO << " Author/Repo ID: " << entry.author_repo_id ;
168212 LOG_INFO << " Branch Name: " << entry.branch_name ;
169213 LOG_INFO << " Path to model.yaml: " << entry.path_to_model_yaml ;
@@ -188,14 +232,18 @@ cpp::result<bool, std::string> Models::AddModelEntry(ModelEntry new_entry,
188232
189233 SQLite::Statement insert (
190234 db_,
191- " INSERT INTO models (model_id, author_repo_id , "
192- " branch_name, path_to_model_yaml, model_alias) VALUES (?, ?, "
193- " ?, ?, ?)" );
235+ " INSERT INTO models (model_id, model_format, model_source, status , "
236+ " engine, author_repo_id, branch_name, path_to_model_yaml, model_alias) "
237+ " VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)" );
194238 insert.bind (1 , new_entry.model );
195- insert.bind (2 , new_entry.author_repo_id );
196- insert.bind (3 , new_entry.branch_name );
197- insert.bind (4 , new_entry.path_to_model_yaml );
198- insert.bind (5 , new_entry.model_alias );
239+ insert.bind (2 , new_entry.model_format );
240+ insert.bind (3 , new_entry.model_source );
241+ insert.bind (4 , StatusToString (new_entry.status ));
242+ insert.bind (5 , new_entry.engine );
243+ insert.bind (6 , new_entry.author_repo_id );
244+ insert.bind (7 , new_entry.branch_name );
245+ insert.bind (8 , new_entry.path_to_model_yaml );
246+ insert.bind (9 , new_entry.model_alias );
199247 insert.exec ();
200248
201249 return true ;
@@ -215,14 +263,19 @@ cpp::result<bool, std::string> Models::UpdateModelEntry(
215263 try {
216264 SQLite::Statement upd (db_,
217265 " UPDATE models "
218- " SET author_repo_id = ?, branch_name = ?, "
266+ " SET model_format = ?, model_source = ?, status = ?, "
267+ " engine = ?, author_repo_id = ?, branch_name = ?, "
219268 " path_to_model_yaml = ? "
220269 " WHERE model_id = ? OR model_alias = ?" );
221- upd.bind (1 , updated_entry.author_repo_id );
222- upd.bind (2 , updated_entry.branch_name );
223- upd.bind (3 , updated_entry.path_to_model_yaml );
224- upd.bind (4 , identifier);
225- upd.bind (5 , identifier);
270+ upd.bind (1 , updated_entry.model_format );
271+ upd.bind (2 , updated_entry.model_source );
272+ upd.bind (3 , StatusToString (updated_entry.status ));
273+ upd.bind (4 , updated_entry.engine );
274+ upd.bind (5 , updated_entry.author_repo_id );
275+ upd.bind (6 , updated_entry.branch_name );
276+ upd.bind (7 , updated_entry.path_to_model_yaml );
277+ upd.bind (8 , identifier);
278+ upd.bind (9 , identifier);
226279 return upd.exec () == 1 ;
227280 } catch (const std::exception& e) {
228281 return cpp::fail (e.what ());
@@ -305,4 +358,5 @@ bool Models::HasModel(const std::string& identifier) const {
305358 return false ;
306359 }
307360}
308- } // namespace cortex::db
361+
362+ } // namespace cortex::db
0 commit comments