Skip to content

Commit 9bb2895

Browse files
authored
Merge pull request #197 from datakind/feat/case-insensitive-inst-lookup
feat: Add case-insensitive institution name lookup
2 parents 8c4b4a6 + eb6e5d8 commit 9bb2895

File tree

2 files changed

+62
-10
lines changed

2 files changed

+62
-10
lines changed

src/webapp/routers/institutions.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from pydantic import BaseModel
88
from sqlalchemy.orm import Session
99
from sqlalchemy.future import select
10-
from sqlalchemy import and_, delete
10+
from sqlalchemy import and_, delete, func
1111

1212
from ..utilities import (
1313
has_access_to_inst_or_err,
@@ -152,7 +152,7 @@ def create_institution(
152152
requested_schemas += PDP_SCHEMA_GROUP
153153
# if no schema is set and PDP is not set, we default to custom.
154154
if not requested_schemas:
155-
requested_schemas = {SchemaType.UNKNOWN}
155+
requested_schemas = [SchemaType.UNKNOWN]
156156
local_session.get().add(
157157
InstTable(
158158
name=req.name,
@@ -357,11 +357,17 @@ def read_inst_name(
357357
"""Returns overview data on a specific institution.
358358
359359
The root-level API view. Only visible to users of that institution or Datakinder access types.
360+
361+
Note: Name matching is case-insensitive. The function will match institution names
362+
regardless of the case of the input parameter. If multiple institutions with the same
363+
name (case-insensitive) exist, this will raise an error.
360364
"""
361365
local_session.set(sql_session)
362366
query_result = (
363367
local_session.get()
364-
.execute(select(InstTable).where(InstTable.name == inst_name))
368+
.execute(
369+
select(InstTable).where(func.lower(InstTable.name) == func.lower(inst_name))
370+
)
365371
.all()
366372
)
367373

src/webapp/routers/institutions_test.py

Lines changed: 53 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import uuid
44
import os
55
from datetime import datetime
6+
from typing import Generator
67
from unittest import mock
78
import pytest
89
import sqlalchemy
@@ -81,7 +82,9 @@ def session_fixture():
8182

8283

8384
@pytest.fixture(name="client")
84-
def client_fixture(session: sqlalchemy.orm.Session):
85+
def client_fixture(
86+
session: sqlalchemy.orm.Session,
87+
) -> Generator[TestClient, None, None]:
8588
"""Unit test mocks setup for a non-DATAKINDER type."""
8689

8790
def get_session_override():
@@ -108,7 +111,9 @@ def databricks_control_override():
108111

109112

110113
@pytest.fixture(name="datakinder_client")
111-
def datakinder_client_fixture(session: sqlalchemy.orm.Session):
114+
def datakinder_client_fixture(
115+
session: sqlalchemy.orm.Session,
116+
) -> Generator[TestClient, None, None]:
112117
"""Unit test mocks setup for a DATAKINDER type."""
113118

114119
def get_session_override():
@@ -134,7 +139,7 @@ def databricks_control_override():
134139
app.dependency_overrides.clear()
135140

136141

137-
def test_read_all_inst(client: TestClient):
142+
def test_read_all_inst(client: TestClient) -> None:
138143
"""Test GET /institutions."""
139144

140145
# Unauthorized.
@@ -146,7 +151,7 @@ def test_read_all_inst(client: TestClient):
146151
)
147152

148153

149-
def test_read_all_inst_datakinder(datakinder_client: TestClient):
154+
def test_read_all_inst_datakinder(datakinder_client: TestClient) -> None:
150155
"""Test GET /institutions using DATAKINDER type."""
151156
# Authorized.
152157
response = datakinder_client.get("/institutions")
@@ -176,7 +181,7 @@ def test_read_all_inst_datakinder(datakinder_client: TestClient):
176181
]
177182

178183

179-
def test_read_inst_by_name(client: TestClient):
184+
def test_read_inst_by_name(client: TestClient) -> None:
180185
"""Test GET /institutions/name/<name>. For various user access types."""
181186
# Unauthorized.
182187
response = client.get("/institutions/name/school_1")
@@ -193,7 +198,48 @@ def test_read_inst_by_name(client: TestClient):
193198
assert response.json() == INSTITUTION_OBJ
194199

195200

196-
def test_read_inst_by_pdp_id(client: TestClient):
201+
def test_read_inst_by_name_case_insensitive(client: TestClient) -> None:
202+
"""Test GET /institutions/name/<name> with case-insensitive matching."""
203+
# Test with different case variations - should all match
204+
test_cases = [
205+
"valid_school", # Original case
206+
"Valid_School", # Title case
207+
"VALID_SCHOOL", # All uppercase
208+
"vAlId_ScHoOl", # Mixed case
209+
]
210+
211+
for name_variant in test_cases:
212+
response = client.get(f"/institutions/name/{name_variant}")
213+
assert response.status_code == 200, f"Failed for variant: {name_variant}"
214+
assert response.json() == INSTITUTION_OBJ, (
215+
f"Response mismatch for variant: {name_variant}"
216+
)
217+
218+
219+
def test_read_inst_by_name_case_insensitive_lowercase(
220+
datakinder_client: TestClient,
221+
) -> None:
222+
"""Test GET /institutions/name/<name> with lowercase input when DB has mixed case."""
223+
# Test that lowercase input matches mixed case in database
224+
# Using datakinder_client since regular client doesn't have access to school_1
225+
response = datakinder_client.get("/institutions/name/school_1")
226+
assert response.status_code == 200
227+
# Verify it matches the institution with name "school_1" (lowercase in DB)
228+
assert response.json()["name"] == "school_1"
229+
230+
231+
def test_read_inst_by_name_case_insensitive_uppercase(
232+
datakinder_client: TestClient,
233+
) -> None:
234+
"""Test GET /institutions/name/<name> with uppercase input."""
235+
# Test that uppercase input matches lowercase in database
236+
# Using datakinder_client since regular client doesn't have access to school_1
237+
response = datakinder_client.get("/institutions/name/SCHOOL_1")
238+
assert response.status_code == 200
239+
assert response.json()["name"] == "school_1"
240+
241+
242+
def test_read_inst_by_pdp_id(client: TestClient) -> None:
197243
"""Test GET /institutions/pdp-id/<pdp_id>. For various user access types."""
198244
# Unauthorized.
199245
response = client.get("/institutions/pdp-id/456")
@@ -210,7 +256,7 @@ def test_read_inst_by_pdp_id(client: TestClient):
210256
assert response.json() == INSTITUTION_OBJ
211257

212258

213-
def test_read_inst(client: TestClient):
259+
def test_read_inst(client: TestClient) -> None:
214260
"""Test GET /institutions/<uuid>. For various user access types."""
215261
# Unauthorized.
216262
response = client.get("/institutions/" + uuid_to_str(UUID_1))

0 commit comments

Comments
 (0)