@@ -69,29 +69,66 @@ def summaryp(
6969 except (ValueError , AttributeError , ZeroDivisionError ) as err :
7070 print (f"{ type (err ).__name__ } : { str (err )} " )
7171 return None
72- if isinstance (word_error_rate_breakdown [0 ], np .ndarray ):
73- word_error_rate_breakdown = word_error_rate_breakdown .tolist ()
74- transform_word_error_rate_breakdown = np .transpose (word_error_rate_breakdown )
75- weighted_insertions = transform_word_error_rate_breakdown [3 ] * insertions_weight
76- weighted_deletions = transform_word_error_rate_breakdown [4 ] * deletions_weight
77- weighted_substitutions = (
78- transform_word_error_rate_breakdown [5 ] * substitutions_weight
79- )
80- m = transform_word_error_rate_breakdown [2 ]
81- weighted_errors = sum (
82- (weighted_insertions , weighted_deletions , weighted_substitutions )
83- )
84- werps_result = (weighted_errors / m ).tolist ()
72+
73+ b = word_error_rate_breakdown
74+
75+ # Unwrap 0-D container
76+ if isinstance (b , np .ndarray ) and b .ndim == 0 :
77+ b = b .item ()
78+
79+ if isinstance (b , np .ndarray ):
80+ if b .ndim == 2 :
81+ # True 2-D numeric batch
82+ word_error_rate_breakdown = b .tolist ()
83+ t = b .T
84+ weighted_insertions = t [3 ] * insertions_weight
85+ weighted_deletions = t [4 ] * deletions_weight
86+ weighted_substitutions = t [5 ] * substitutions_weight
87+ m = t [2 ]
88+ weighted_errors = weighted_insertions + weighted_deletions + weighted_substitutions
89+ werps_result = (weighted_errors / m ).tolist ()
90+
91+ elif b .ndim == 1 :
92+ # Could be either:
93+ # (a) single example row vector, or
94+ # (b) object array of per-example vectors
95+ first = b [0 ] if b .size else None
96+
97+ if isinstance (first , (np .ndarray , list , tuple )):
98+ # Batch stored as 1-D object array of per-example vectors (ragged fields exist)
99+ word_error_rate_breakdown = []
100+ werps_result = []
101+ for r in b :
102+ rr = r .tolist () if isinstance (r , np .ndarray ) else r
103+ word_error_rate_breakdown .append (rr )
104+ w_ins = float (rr [3 ]) * insertions_weight
105+ w_del = float (rr [4 ]) * deletions_weight
106+ w_sub = float (rr [5 ]) * substitutions_weight
107+ m_val = float (rr [2 ])
108+ weighted_wer = (w_ins + w_del + w_sub ) / m_val if m_val else 0.0
109+ werps_result .append (weighted_wer )
110+ else :
111+ # Single example vector - wrap in list for DataFrame
112+ word_error_rate_breakdown = [b .tolist ()]
113+ weighted_insertions = b [3 ] * insertions_weight
114+ weighted_deletions = b [4 ] * deletions_weight
115+ weighted_substitutions = b [5 ] * substitutions_weight
116+ m = b [2 ]
117+ weighted_errors = weighted_insertions + weighted_deletions + weighted_substitutions
118+ werps_result = float (weighted_errors / m ) if m else 0.0
119+
120+ else :
121+ raise ValueError (f"Unexpected metrics output ndim: { b .ndim } " )
122+
85123 else :
86- word_error_rate_breakdown = [word_error_rate_breakdown .tolist ()]
87- weighted_insertions = word_error_rate_breakdown [0 ][3 ] * insertions_weight
88- weighted_deletions = word_error_rate_breakdown [0 ][4 ] * deletions_weight
89- weighted_substitutions = word_error_rate_breakdown [0 ][5 ] * substitutions_weight
90- m = word_error_rate_breakdown [0 ][2 ]
91- weighted_errors = sum (
92- (weighted_insertions , weighted_deletions , weighted_substitutions )
93- )
94- werps_result = weighted_errors / m
124+ # Non-numpy fallback (assume [wer, ld, m, ...])
125+ word_error_rate_breakdown = [b .tolist () if hasattr (b , 'tolist' ) else b ]
126+ weighted_insertions = b [3 ] * insertions_weight
127+ weighted_deletions = b [4 ] * deletions_weight
128+ weighted_substitutions = b [5 ] * substitutions_weight
129+ m = b [2 ]
130+ weighted_errors = weighted_insertions + weighted_deletions + weighted_substitutions
131+ werps_result = float (weighted_errors / m ) if m else 0.0
95132
96133 columns = [
97134 "wer" ,
0 commit comments