@@ -151,14 +151,16 @@ def test_model_7b_logits_bf16(self):
151
151
{
152
152
("xpu" , 3 ): torch .tensor ([[- 6.5208 , - 4.1218 , - 4.9377 , - 3.2536 , 0.8127 , - 2.9811 , 1.2918 , - 3.3848 ]]),
153
153
("cuda" , 7 ): torch .tensor ([[- 6.5061 , - 4.1147 , - 4.9669 , - 3.2038 , 0.8069 , - 2.9694 , 1.2864 , - 3.3786 ]]),
154
- ("cuda" , 8 ): torch .tensor ([[- 6.5208 , - 4.1218 , - 4.9377 , - 3.2536 , 0.8127 , - 2.9811 , 1.2918 , - 3.3848 ]])
155
- })
154
+ ("cuda" , 8 ): torch .tensor ([[- 6.5208 , - 4.1218 , - 4.9377 , - 3.2536 , 0.8127 , - 2.9811 , 1.2918 , - 3.3848 ]]),
155
+ ("rocm" , (9 , 4 )): torch .tensor ([[- 6.5094 , - 4.1329 , - 4.9754 , - 3.5042 , 0.8082 , - 2.9443 , 1.2830 , - 3.3539 ]]),
156
+ })
156
157
157
- expected_mean = expected_means .get_expectation ()
158
+ expected_mean = expected_means .get_expectation ().to (torch_device )
159
+ actual_mean = out .logits .float ().mean (- 1 )
158
160
self .assertTrue (
159
161
torch .allclose (
160
- expected_mean . to ( torch_device ) ,
161
- out . logits . float (). mean ( - 1 ) ,
162
+ expected_mean ,
163
+ actual_mean ,
162
164
atol = 1e-2 ,
163
165
rtol = 1e-2
164
166
)
@@ -169,18 +171,13 @@ def test_model_7b_logits_bf16(self):
169
171
{
170
172
("xpu" , 3 ): torch .tensor ([[- 12.5625 , - 7.1250 , - 0.6289 , - 7.8750 , - 6.9688 , - 7.8125 , - 6.5000 , - 7.4375 , - 7.6562 , - 6.9688 , - 6.0312 , - 7.0312 , - 1.8203 , 1.8750 , - 8.5000 ]]),
171
173
("cuda" , 7 ): torch .tensor ([[- 12.5000 , - 7.0625 , - 0.6289 , - 7.8750 , - 6.9688 , - 7.8125 , - 6.4688 , - 7.4375 , - 7.6875 , - 6.9375 , - 6.0312 , - 7.0000 , - 1.8594 , 1.8438 , - 8.5000 ]]),
172
- ("cuda" , 8 ): torch .tensor ([[- 12.5625 , - 7.1250 , - 0.6289 , - 7.8750 , - 6.9688 , - 7.8125 , - 6.5000 , - 7.4375 , - 7.6562 , - 6.9688 , - 6.0312 , - 7.0312 , - 1.8203 , 1.8750 , - 8.5000 ]])
174
+ ("cuda" , 8 ): torch .tensor ([[- 12.5625 , - 7.1250 , - 0.6289 , - 7.8750 , - 6.9688 , - 7.8125 , - 6.5000 , - 7.4375 , - 7.6562 , - 6.9688 , - 6.0312 , - 7.0312 , - 1.8203 , 1.8750 , - 8.5000 ]]),
175
+ ("rocm" , (9 , 4 )): torch .tensor ([[- 12.5000 , - 7.0625 , - 0.6289 , - 7.8750 , - 6.9688 , - 7.8125 , - 6.5000 , - 7.4375 , - 7.6562 , - 6.9375 , - 6.0312 , - 7.0312 , - 1.8594 , 1.8438 , - 8.5000 ]])
173
176
})
174
177
# fmt: on
175
- expected_slice = expected_slices .get_expectation ()
176
- self .assertTrue (
177
- torch .allclose (
178
- expected_slice .to (torch_device ),
179
- out .logits [0 , 0 , :15 ].float (),
180
- atol = 1e-2 ,
181
- rtol = 1e-2 ,
182
- )
183
- )
178
+ expected_slice = expected_slices .get_expectation ().to (torch_device )
179
+ actual_slice = out .logits [0 , 0 , :15 ].float ()
180
+ self .assertTrue (torch .allclose (expected_slice , actual_slice , atol = 1e-2 , rtol = 1e-2 ))
184
181
185
182
@slow
186
183
def test_model_7b_logits (self ):
0 commit comments