Skip to content

Commit 901144f

Browse files
authored
✅ Added Tests for EnvironmentBasedModelMapper (#422)
While working to confirm that the logic in codemodder-java and the platform is in sync, I found opportunities to add missing tests and fix some pitfalls. /towards #work
1 parent 4a98fa1 commit 901144f

File tree

2 files changed

+79
-5
lines changed

2 files changed

+79
-5
lines changed
Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,34 @@
11
package io.codemodder.plugins.llm;
22

33
import java.util.HashMap;
4+
import java.util.Map;
45

56
/** Mapper that maps models to their deployment names based on environment variables. */
67
final class EnvironmentBasedModelMapper implements ModelMapper {
7-
private static final String DEPLOYMENT_TEMPLATE = "CODEMODDER_AZURE_OPENAI_%s_DEPLOYMENT";
88

99
private final HashMap<Model, String> map = new HashMap<>();
1010

1111
EnvironmentBasedModelMapper() {
12-
for (Model m : StandardModel.values()) {
13-
final var deployment = System.getenv(String.format(DEPLOYMENT_TEMPLATE, m));
14-
map.put(m, deployment == null ? m.id() : deployment);
12+
this(System.getenv());
13+
}
14+
15+
EnvironmentBasedModelMapper(final Map<String, String> environment) {
16+
for (final Model model : StandardModel.values()) {
17+
final var name = String.format(DEPLOYMENT_TEMPLATE, toEnvironmentVariableCase(model.id()));
18+
final var deployment = environment.getOrDefault(name, model.id());
19+
map.put(model, deployment);
1520
}
1621
}
1722

1823
@Override
1924
public String getModelName(Model model) {
20-
return map.get(model);
25+
return map.getOrDefault(model, model.id());
2126
}
27+
28+
/** Converts a model ID to environment variable casing. */
29+
private static String toEnvironmentVariableCase(String input) {
30+
return input.toUpperCase().replace('-', '_').replace('.', '_');
31+
}
32+
33+
private static final String DEPLOYMENT_TEMPLATE = "CODEMODDER_AZURE_OPENAI_%s_DEPLOYMENT";
2234
}
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
package io.codemodder.plugins.llm;
2+
3+
import static org.assertj.core.api.Assertions.assertThat;
4+
import static org.mockito.Mockito.mock;
5+
import static org.mockito.Mockito.when;
6+
import static org.mockito.Mockito.withSettings;
7+
8+
import java.util.Map;
9+
import org.junit.jupiter.api.BeforeAll;
10+
import org.junit.jupiter.api.Test;
11+
import org.junit.jupiter.api.TestInstance;
12+
import org.junit.jupiter.params.ParameterizedTest;
13+
import org.junit.jupiter.params.provider.EnumSource;
14+
15+
/** Unit tests for {@link EnvironmentBasedModelMapper}. */
16+
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
17+
final class EnvironmentBasedModelMapperTest {
18+
19+
private EnvironmentBasedModelMapper mapper;
20+
21+
@BeforeAll
22+
void before() {
23+
final var environment =
24+
Map.of(
25+
"CODEMODDER_AZURE_OPENAI_GPT_3_5_TURBO_0125_DEPLOYMENT",
26+
"my-gpt-3.5-turbo",
27+
"CODEMODDER_AZURE_OPENAI_GPT_4_0613_DEPLOYMENT",
28+
"my-gpt-4",
29+
"CODEMODDER_AZURE_OPENAI_GPT_4_TURBO_2024_04_09_DEPLOYMENT",
30+
"my-gpt-4-turbo",
31+
"CODEMODDER_AZURE_OPENAI_GPT_4O_2024_05_13_DEPLOYMENT",
32+
"my-gpt-4o");
33+
mapper = new EnvironmentBasedModelMapper(environment);
34+
}
35+
36+
/** Spot checks one of the standard models to make sure the mapping works as expected */
37+
@Test
38+
void it_maps_model_name_to_deployment() {
39+
final var name = mapper.getModelName(StandardModel.GPT_3_5_TURBO_0125);
40+
assertThat(name).isEqualTo("my-gpt-3.5-turbo");
41+
}
42+
43+
/**
44+
* This is a meta-test that fails when we add a new standard model but forget to update the
45+
* mapping in {@link #before()} to ensure that all standard models are covered.
46+
*/
47+
@EnumSource(StandardModel.class)
48+
@ParameterizedTest
49+
void it_looks_up_all_standard_models(final Model model) {
50+
final var name = mapper.getModelName(model);
51+
assertThat(name).isNotEqualTo(model.id()).startsWith("my-gpt");
52+
}
53+
54+
@Test
55+
void it_returns_model_id_when_no_mapping_exists() {
56+
// GIVEN some model that doesn't have a mapping
57+
final var model = mock(Model.class, withSettings().stubOnly());
58+
when(model.id()).thenReturn("test");
59+
final var name = mapper.getModelName(model);
60+
assertThat(name).isEqualTo(model.id());
61+
}
62+
}

0 commit comments

Comments
 (0)