
Commit 6af11ba

Fix tests
1 parent 66a3b05 commit 6af11ba

File tree

1 file changed: +37 -57 lines changed


tests/library/test_gen.py

Lines changed: 37 additions & 57 deletions
@@ -74,75 +74,53 @@ def test_stop_quote(selected_model):
 
 
 def test_metrics_smoke(selected_model: models.Model):
-    lm = selected_model.copy()
-    lm.reset_metrics()
+    lm = selected_model
+    lm.engine.reset_metrics()
 
     lm += "abcd"
-    print(f"{lm.engine_metrics=}")
+    print(f"{lm.engine.metrics=}")
     lm += gen("first", max_tokens=1)
-    print(f"{lm.engine_metrics=}")
+    print(f"{lm.engine.metrics=}")
     # Can't be sure of exact count due to token healing
     assert (
-        lm.engine_metrics.generated_tokens == 1
-        or lm.engine_metrics.generated_tokens == 2
+        lm.engine.metrics.model_output_tokens == 1
+        or lm.engine.metrics.model_output_tokens == 2
     )
-    assert lm.engine_metrics.forced_tokens == 0
+    assert lm.engine.metrics.model_input_tokens > 1
 
     lm += "fg"
     lm += gen("second", max_tokens=1)
     # Again, trouble with healing
     assert (
-        lm.engine_metrics.generated_tokens >= 2
-        and lm.engine_metrics.generated_tokens <= 4
+        lm.engine.metrics.model_output_tokens == 1
+        or lm.engine.metrics.model_output_tokens == 2
+    )
+    assert (
+        lm.engine.metrics.model_output_tokens >= 2
+        or lm.engine.metrics.model_output_tokens <= 4
     )
-    assert lm.engine_metrics.forced_tokens == 0
 
 
 def test_metrics_select(selected_model: models.Model):
-    lm = selected_model.copy()
-    lm.reset_metrics()
-
-    lm += "This is a great day to "
-    lm += select(["ride a bike", "row a boat", "go for a swim"])
-    print(f"lm={str(lm)}")
-    print(f"{lm.engine_metrics=}")
-    assert lm.engine_metrics.forced_tokens > 0
-    assert lm.engine_metrics.generated_tokens > 0
-    assert lm.engine_metrics.forced_tokens > lm.engine_metrics.generated_tokens
-    prev_stats = lm.engine_metrics.copy()
-    lm += " and afterwards "
-    lm += select(["walk to town", "walk to a show"])
-    print(f"lm={str(lm)}")
-    print(f"{lm.engine_metrics=}")
-    assert lm.engine_metrics.forced_tokens > prev_stats.forced_tokens
-    assert lm.engine_metrics.generated_tokens > prev_stats.generated_tokens
-
-
-def test_metrics_alt_expressions(selected_model: models.Model):
-    lm = selected_model.copy()
-    lm2 = selected_model.copy()
-    lm.reset_metrics()
-    lm2.reset_metrics()
-
-    prompt = "abcdefg"
-
-    lm += prompt + gen(max_tokens=10)
-    print(f"\nlm={str(lm)}")
-    print(f"{lm.engine_metrics=}\n")
-
-    lm2 += prompt
-    lm2 += gen(max_tokens=10)
-    print(f"\nlm2={str(lm2)}")
-    print(f"{lm2.engine_metrics=}\n")
-
-    assert str(lm) == str(lm2)
-    assert lm.engine_metrics.generated_tokens == 10
-    assert lm2.engine_metrics.generated_tokens == 10
-
-    assert (
-        lm.engine_metrics.forced_tokens + lm.engine_metrics.model_input_tokens
-        == lm2.engine_metrics.forced_tokens + lm2.engine_metrics.model_input_tokens
+    lm = selected_model
+    lm.engine.reset_metrics()
+
+    lm += "I will "
+    lm += select(
+        [
+            "ride a bicycle down the road",
+            "row in a boat along the river",
+            "go for a swim in the ocean",
+        ]
     )
+    print(f"lm={str(lm)}")
+    print(f"{lm.engine.metrics=}")
+    assert lm.engine.metrics.model_input_tokens > 1
+    assert lm.engine.metrics.model_output_tokens > 0
+    # Guidance should be able to force the generation after only a couple of tokens
+    # so even though the options are long, relatively few output tokens should be
+    # needed
+    assert lm.engine.metrics.model_input_tokens > lm.engine.metrics.model_output_tokens
 
 
 def test_unicode(selected_model):
@@ -159,14 +137,16 @@ def test_unicode(selected_model):
 
 def test_unicode2(selected_model: models.Model):
     lm = selected_model
-    lm.reset_metrics()
+    lm.engine.reset_metrics()
     prompt = "Janet’s ducks lay 16 eggs per day"
     lm += prompt + gen(max_tokens=10)
+    assert lm.engine.metrics.model_input_tokens > 1
+    # Due to token healing, we can't be sure of the
+    # precise output count
     assert (
-        lm.engine_metrics.generated_tokens == 10
-        or lm.engine_metrics.generated_tokens == 11
+        lm.engine.metrics.model_output_tokens == 10
+        or lm.engine.metrics.model_output_tokens == 11
    )
-    assert lm.engine_metrics.forced_tokens == 0
 
 
 def test_gsm8k():
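
For reference, a minimal standalone sketch of the metrics pattern the tests now exercise. It assumes the guidance-style gen/select/models imports and the lm.engine.metrics fields (model_input_tokens, model_output_tokens) shown in the diff; the concrete model constructor is only an illustrative choice, not part of the commit.

# Hedged sketch, not part of the commit: mirrors the updated test pattern.
from guidance import gen, models, select

# Any concrete Model works here; Transformers("gpt2") is just an example backend.
lm = models.Transformers("gpt2")
lm.engine.reset_metrics()

# A constrained select() lets the engine force most of the continuation,
# then gen() samples a few free tokens.
lm += "I will "
lm += select(["ride a bicycle", "row a boat", "go for a swim"])
lm += gen("rest", max_tokens=5)

metrics = lm.engine.metrics
print(metrics.model_input_tokens, metrics.model_output_tokens)
# Because select() forces tokens once an option is determined,
# model_input_tokens is expected to exceed model_output_tokens,
# which is what the new test_metrics_select asserts.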
