Skip to content

Commit 12429d1

Browse files
authored
fix an error in gemma accuracy example (#351)
1 parent 37ea0a4 commit 12429d1

File tree

2 files changed

+11
-9
lines changed

2 files changed

+11
-9
lines changed

example/gemma/README.md

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -131,26 +131,28 @@ Based on lm-eval-harness.
131131
| **Parameter** | **Default Value** |
132132
| :---: | :--- |
133133
| **model** | gemma |
134-
| **model_name** | gemma_2b_en |
135-
| **dtype** | bfloat16 |
136-
| **num_beams** | 1 |
134+
| **model_name** | **gemma_2b_en**, gemma_7b_en|
135+
| **dtype** | **bfloat16**, float16, float32 |
136+
| **num_beams** | **1** |
137137
```
138138
git clone https://github.com/EleutherAI/lm-evaluation-harness.git lm_eval
139139
cd lm_eval
140140
git checkout b281b0921b636bc36ad05c0b0b0763bd6dd43463
141141
git apply ../gemma.patch
142142
pip install -r requirements.txt
143+
pip install torch --index-url https://download.pytorch.org/whl/cpu --force-reinstall
144+
export KERAS_BACKEND=jax
143145
python main.py \
144146
--model gemma \
145-
--model_args model_name=gemma_2b_en,dtype=float32,num_beams=1 \
147+
--model_args model_name=gemma_7b_en,dtype=bfloat16,num_beams=4 \
146148
--tasks openbookqa \
147149
--no_cache
148150
```
149151
### Output
150152
```
151-
gemma (model_name=gemma_2b_en,dtype=float32,num_beams=1), limit: None, provide_description: False, num_fewshot: 0, batch_size: None
153+
gemma (model_name=gemma_7b_en,dtype=bfloat16,num_beams=4), limit: None, provide_description: False, num_fewshot: 0, batch_size: None
152154
| Task |Version| Metric |Value| |Stderr|
153155
|----------|------:|--------|----:|---|-----:|
154-
|openbookqa| 0|acc |0.302|± |0.0206|
155-
| | |acc_norm|0.398|± |0.0219|
156+
|openbookqa| 0|acc |0.326|± |0.0210|
157+
| | |acc_norm|0.454|± |0.0223|
156158
```

example/gemma/gemma.patch

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ index 8ca27fac..6127ce6c 100644
2020

2121
diff --git a/lm_eval/models/gemma.py b/lm_eval/models/gemma.py
2222
new file mode 100644
23-
index 00000000..bc4540f7
23+
index 00000000..732185c4
2424
--- /dev/null
2525
+++ b/lm_eval/models/gemma.py
2626
@@ -0,0 +1,79 @@
@@ -41,7 +41,7 @@ index 00000000..bc4540f7
4141
+ self.model = keras_nlp.models.GemmaCausalLM.from_preset(model_name)
4242
+ if num_beams > 1:
4343
+ from keras_nlp.samplers import BeamSampler
44-
+ model.compile(sampler=BeamSampler(num_beams=args.num_beams))
44+
+ self.model.compile(sampler=BeamSampler(num_beams=num_beams))
4545
+
4646
+ @property
4747
+ def eot_token_id(self):

0 commit comments

Comments
 (0)