@@ -1005,35 +1005,41 @@ def get_ns(n: int):
10051005        ) as  prediction_pipeline :
10061006            for  n , batch_size , inputs , exptected_output_shape  in  generate_test_cases ():
10071007                error : Optional [str ] =  None 
1008-                 result  =  prediction_pipeline .predict_sample_without_blocking (inputs )
1009-                 if  len (result .members ) !=  len (exptected_output_shape ):
1010-                     error  =  (
1011-                         f"Expected { len (exptected_output_shape )}   outputs," 
1012-                         +  f" but got { len (result .members )}  " 
1013-                     )
1014- 
1008+                 try :
1009+                     result  =  prediction_pipeline .predict_sample_without_blocking (inputs )
1010+                 except  Exception  as  e :
1011+                     error  =  str (e )
10151012                else :
1016-                     for   m ,  exp   in   exptected_output_shape . items ( ):
1017-                         res  =  result . members . get ( m ) 
1018-                         if   res   is   None : 
1019-                             error   =   "Output tensors may not be None for test case "
1020-                              break 
1013+                     if   len ( result . members )  !=   len ( exptected_output_shape ):
1014+                         error  =  ( 
1015+                              f"Expected  { len ( exptected_output_shape ) }  outputs," 
1016+                             +   f" but got  { len ( result . members ) }  "
1017+                         ) 
10211018
1022-                         diff : Dict [AxisId , int ] =  {}
1023-                         for  a , s  in  res .sizes .items ():
1024-                             if  isinstance ((e_aid  :=  exp [AxisId (a )]), int ):
1025-                                 if  s  !=  e_aid :
1019+                     else :
1020+                         for  m , exp  in  exptected_output_shape .items ():
1021+                             res  =  result .members .get (m )
1022+                             if  res  is  None :
1023+                                 error  =  "Output tensors may not be None for test case" 
1024+                                 break 
1025+ 
1026+                             diff : Dict [AxisId , int ] =  {}
1027+                             for  a , s  in  res .sizes .items ():
1028+                                 if  isinstance ((e_aid  :=  exp [AxisId (a )]), int ):
1029+                                     if  s  !=  e_aid :
1030+                                         diff [AxisId (a )] =  s 
1031+                                 elif  (
1032+                                     s  <  e_aid .min 
1033+                                     or  e_aid .max  is  not   None 
1034+                                     and  s  >  e_aid .max 
1035+                                 ):
10261036                                    diff [AxisId (a )] =  s 
1027-                             elif  (
1028-                                 s  <  e_aid .min  or  e_aid .max  is  not   None  and  s  >  e_aid .max 
1029-                             ):
1030-                                 diff [AxisId (a )] =  s 
1031-                         if  diff :
1032-                             error  =  (
1033-                                 f"(n={ n }  ) Expected output shape { exp }  ," 
1034-                                 +  f" but got { res .sizes }   (diff: { diff }  )" 
1035-                             )
1036-                             break 
1037+                             if  diff :
1038+                                 error  =  (
1039+                                     f"(n={ n }  ) Expected output shape { exp }  ," 
1040+                                     +  f" but got { res .sizes }   (diff: { diff }  )" 
1041+                                 )
1042+                                 break 
10371043
10381044                model .validation_summary .add_detail (
10391045                    ValidationDetail (
0 commit comments