Skip to content

Commit 7b8a777

Browse files
committed
fix: formatting style
1 parent 2b51021 commit 7b8a777

File tree

2 files changed

+29
-3
lines changed

2 files changed

+29
-3
lines changed

src/webapp/databricks.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -505,7 +505,11 @@ def _consume_chunk(chunk_obj: Any) -> int | None:
505505
pass
506506
return records
507507

508-
def fetch_model_version(self, model_name: str):
508+
def fetch_model_version(self, catalog_name: str, inst_name: str, model_name: str):
509+
510+
schema = databricksify_inst_name(inst_name)
511+
model_name_path = f"{catalog_name}.{schema}_gold.{model_name}"
512+
509513
try:
510514
w = WorkspaceClient(
511515
host=databricks_vars["DATABRICKS_HOST_URL"],
@@ -520,7 +524,7 @@ def fetch_model_version(self, model_name: str):
520524
raise ValueError(f"setup_new_inst(): Workspace client creation failed: {e}")
521525

522526
model_info = w.model_versions.list(
523-
full_name=model_name,
527+
full_name=model_name_path,
524528
)
525529

526530
return model_info

src/webapp/routers/models.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import traceback
3434
import logging
3535
from ..gcsdbutils import update_db_from_bucket
36+
from ..config import env_vars
3637

3738
from ..gcsutil import StorageControl
3839

@@ -601,9 +602,30 @@ def get_model_versions(
601602
transformed_model_name = str(decode_url_piece(model_name)).strip()
602603
has_access_to_inst_or_err(inst_id, current_user)
603604

605+
local_session.set(sql_session)
606+
query_result = (
607+
local_session.get()
608+
.execute(select(InstTable).where(InstTable.id == str_to_uuid(inst_id)))
609+
.all()
610+
)
611+
if not query_result or len(query_result) == 0:
612+
raise HTTPException(
613+
status_code=status.HTTP_404_NOT_FOUND,
614+
detail="Institution not found.",
615+
)
616+
if len(query_result) > 1:
617+
raise HTTPException(
618+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
619+
detail="Institution duplicates found.",
620+
)
621+
604622
print(f"Initial model name = {model_name}")
605623
print(f"Converted model name {transformed_model_name}")
606624

607-
model_version_info = databricks_control.fetch_model_version(transformed_model_name)
625+
model_version_info = databricks_control.fetch_model_version(
626+
catalog_name=env_vars["CATALOG_NAME"],
627+
inst_name=f"{query_result[0][0].name}",
628+
model_name=transformed_model_name,
629+
)
608630

609631
return model_version_info

0 commit comments

Comments
 (0)