@@ -109,7 +109,11 @@ def torch_random_sample(data, random_val, topp, topk, voc, temperature):
109109 idx = torch .searchsorted (cum_probs , threshold )
110110 except Exception :
111111 indices = (cum_probs >= threshold ).nonzero (as_tuple = True )[0 ]
112- idx = indices [0 ] if indices .numel () > 0 else torch .tensor (len (cum_probs ) - 1 , device = cum_probs .device )
112+ idx = (
113+ indices [0 ]
114+ if indices .numel () > 0
115+ else torch .tensor (len (cum_probs ) - 1 , device = cum_probs .device )
116+ )
113117 return sorted_indices [idx ]
114118
115119 return torch .argmax (data )
@@ -191,41 +195,41 @@ def infinicore_operator(self, logits, out=None, **kwargs):
191195 def run_test (self , device , test_case , config ):
192196 """
193197 Override run_test to handle random_sample's special comparison logic.
194-
198+
195199 For random_sample, if the indices differ but the logits values at those
196200 indices are equal, the result is still considered valid. This handles
197201 cases where multiple valid indices exist due to floating-point precision.
198-
202+
199203 This is necessary because random_sample can return different valid indices
200204 when multiple positions have the same logits value, especially with
201205 low-precision types like bfloat16 due to floating-point rounding.
202206 """
203207 # Clear stored logits before test to ensure fresh generation
204208 self ._current_logits = None
205-
209+
206210 try :
207211 # Try the standard comparison first
208212 # This will call prepare_inputs_and_kwargs which will set self._current_logits
209213 return super ().run_test (device , test_case , config )
210- except AssertionError :
214+ except AssertionError as original_error :
211215 # If standard comparison fails, check if this is a valid case where
212216 # indices differ but logits values are equal
213-
217+
214218 # Only handle if we have stored logits (from prepare_inputs_and_kwargs)
215219 if self ._current_logits is None :
216220 raise
217-
221+
218222 logits_tensor = self ._current_logits
219-
223+
220224 # Re-run operations with the same logits to get results for comparison
221225 # prepare_inputs_and_kwargs will reuse self._current_logits if it exists
222226 from framework .utils import (
223227 infinicore_tensor_from_torch ,
224228 convert_infinicore_to_torch ,
225229 )
226-
230+
227231 inputs , kwargs = self .prepare_inputs_and_kwargs (test_case , device )
228-
232+
229233 # Prepare infinicore inputs
230234 infini_inputs = []
231235 for inp in inputs :
@@ -235,51 +239,51 @@ def run_test(self, device, test_case, config):
235239 infini_inputs .append (infini_tensor )
236240 else :
237241 infini_inputs .append (inp )
238-
242+
239243 infini_kwargs = kwargs .copy ()
240- if "out" in infini_kwargs and isinstance (infini_kwargs ["out" ], torch .Tensor ):
244+ if "out" in infini_kwargs and isinstance (
245+ infini_kwargs ["out" ], torch .Tensor
246+ ):
241247 cloned_out = infini_kwargs ["out" ].clone ().detach ()
242248 infini_kwargs ["out" ] = infinicore_tensor_from_torch (cloned_out )
243-
249+
244250 # Run both operators
245251 torch_result = self .torch_operator (* inputs , ** kwargs )
246252 infini_result = self .infinicore_operator (* infini_inputs , ** infini_kwargs )
247-
253+
248254 # Extract indices from results
249255 comparison_target = test_case .comparison_target
250256 if comparison_target == "out" :
251257 # Compare output tensor from kwargs
252258 ref_idx = kwargs ["out" ].item ()
253259 torch_result_from_infini = convert_infinicore_to_torch (
254- infini_kwargs ["out" ], kwargs [ "out" ]
260+ infini_kwargs ["out" ]
255261 )
256262 ic_idx = torch_result_from_infini .item ()
257263 else :
258264 # Compare return values
259265 ref_idx = torch_result .item ()
260- torch_result_from_infini = convert_infinicore_to_torch (
261- infini_result , torch_result
262- )
266+ torch_result_from_infini = convert_infinicore_to_torch (infini_result )
263267 ic_idx = torch_result_from_infini .item ()
264-
268+
265269 # Check if indices are equal (standard case)
266270 if ic_idx == ref_idx :
267- return
268-
271+ return True , "passed"
272+
269273 # Special case: indices differ but logits values are equal
270274 # This is valid for random_sample when multiple indices have the same logits value
271275 try :
272276 logits_ref = logits_tensor [ref_idx ].item ()
273277 logits_ic = logits_tensor [ic_idx ].item ()
274278 if logits_ic == logits_ref :
275279 # Valid: different indices but same logits value
276- return
280+ return True , "passed"
277281 except (IndexError , RuntimeError ):
278282 # If we can't access the logits, fall through to raise the original error
279283 pass
280-
284+
281285 # If we get here, the results are truly different
282- raise
286+ raise original_error
283287
284288
285289def main ():
0 commit comments