@@ -232,15 +232,20 @@ def splat(cls, value, shape, layout=None, *, is_signed: bool | None = None):
232232 match layout :
233233 case WGMMARowFragLayout ():
234234 if len (shape ) != 1 :
235- raise ValueError
235+ raise ValueError ( "WGMMARowFragLayout requires a 1D shape" )
236236 if shape [0 ] % 64 :
237- raise ValueError
237+ raise ValueError (
238+ "WGMMARowFragLayout requires shape[0] to be a multiple of 64"
239+ )
238240 reg_shape = (shape [0 ] // 64 , 2 )
239241 case WGMMAFragLayout ():
240242 if len (shape ) != 2 :
241- raise ValueError
243+ raise ValueError ( "WGMMAFragLayout requires a 2D shape" )
242244 if shape [0 ] % 64 or shape [1 ] % 8 :
243- raise ValueError
245+ raise ValueError (
246+ "WGMMAFragLayout requires shape[0] to be a multiple of 64, and"
247+ " shape[1] to be a multiple of 8"
248+ )
244249 reg_shape = (shape [0 ] // 64 , shape [1 ] // 8 , 2 , 1 )
245250 value = vector .splat (ir .VectorType .get ((2 ,), value .type ), value )
246251 case WGStridedFragLayout (vec_size = vec_size ):
@@ -283,6 +288,22 @@ def mlir_dtype(self):
283288 case WGMMARowFragLayout () | WGSplatFragLayout ():
284289 return reg_ty
285290
291+ def to_layout (self , new_layout : FragmentedLayout ):
292+ """Converts the fragmented array to the given layout.
293+
294+ At the moment, only conversions from ``WGSplatFragLayout`` are supported.
295+ """
296+ if self .layout == new_layout :
297+ return self
298+ if not isinstance (self .layout , WGSplatFragLayout ):
299+ raise NotImplementedError (
300+ f"Cannot convert from { self .layout } to { new_layout } "
301+ )
302+ [reg ] = self .registers .flat
303+ return type (self ).splat (
304+ reg , self .shape , new_layout , is_signed = self .is_signed
305+ )
306+
286307 def _pointwise (self , op , * other , output_is_signed : bool | None = None ):
287308 is_signed = (
288309 output_is_signed if output_is_signed is not None else self .is_signed
0 commit comments