Skip to content

Commit 9673ee2

Browse files
committed
added database fix
1 parent dcfc11d commit 9673ee2

File tree

2 files changed

+17
-12
lines changed

2 files changed

+17
-12
lines changed

eval/eval.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,9 @@ def setup_custom_parser():
3535

3636
db_group.add_argument("--model_id", type=str, default=None, help="Model UUID for direct database tracking")
3737

38-
parser.add_argument("--use_database", action="store_true", help="Where to use PostgreSQL Database to track results.")
38+
parser.add_argument(
39+
"--use_database", action="store_true", help="Where to use PostgreSQL Database to track results."
40+
)
3941
parser.add_argument(
4042
"--model_name",
4143
type=str,
@@ -232,12 +234,12 @@ def cli_evaluate(args: Optional[argparse.Namespace] = None) -> None:
232234
args.hf_hub_log_args += f",output_path={args.output_path}"
233235
evaluation_tracker = setup_evaluation_tracker(args.output_path, args.use_database)
234236

235-
# If model_id is provided, lookup model name from database
237+
# If model_id is provided, lookup model weights location from database
236238
if args.model_id:
237239
if not args.use_database:
238240
raise ValueError("--use_database must be set to use --model_id.")
239241
try:
240-
model_name = evaluation_tracker.get_model_name_from_db(args.model_id)
242+
model_name = evaluation_tracker.get_model_attribute_from_db(args.model_id, "weights_location")
241243
args.model_args = update_model_args_with_name(args.model_args or "", model_name)
242244
utils.eval_logger.info(f"Retrieved model name from database: {model_name}")
243245
except Exception as e:

eval/eval_tracker.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -370,27 +370,30 @@ def update_evalresults_db(
370370
session=session,
371371
)
372372

373-
def get_model_name_from_db(
374-
self,
375-
model_id: str,
376-
) -> str:
373+
def get_model_attribute_from_db(self, model_id: str, attribute: str) -> str:
377374
"""
378-
Retrieve model name from database using model_id.
375+
Retrieve a specific attribute from a model in the database.
379376
380377
Args:
381378
model_id: UUID string of the model
379+
attribute: Name of the attribute to retrieve (e.g., 'name', 'weights_location')
382380
383381
Returns:
384-
str: Model name from database
382+
str: Value of the requested attribute
385383
386384
Raises:
387-
RuntimeError: If model_id is not found in database or if database operation fails
385+
RuntimeError: If model_id is not found in database or if attribute doesn't exist
386+
ValueError: If model_id is not a valid UUID
388387
"""
389388
with self.session_scope() as session:
390389
try:
391390
model = session.get(Model, uuid.UUID(model_id))
392391
if model is None:
393392
raise RuntimeError(f"Model with id {model_id} not found in database")
394-
return model.name
393+
if not hasattr(model, attribute):
394+
raise RuntimeError(f"Attribute '{attribute}' does not exist on Model")
395+
return getattr(model, attribute)
396+
except ValueError as e:
397+
raise ValueError(f"Invalid UUID format: {str(e)}")
395398
except Exception as e:
396-
raise RuntimeError(f"Database error in get_model_name_from_db: {str(e)}")
399+
raise RuntimeError(f"Database error in get_model_attribute_from_db: {str(e)}")

0 commit comments

Comments
 (0)