@@ -373,42 +373,6 @@ def quantized_matmul(
373373 )
374374
375375
376- @linalg_structured_op
377- def matmul_transpose_a (
378- A = TensorDef (T1 , S .K , S .N ),
379- B = TensorDef (T2 , S .K , S .M ),
380- C = TensorDef (U , S .M , S .N , output = True ),
381- cast = TypeFnAttrDef (default = TypeFn .cast_signed ),
382- ):
383- """Performs a matrix multiplication of two 2D inputs with lhs operand
384- transposed.
385-
386- Numeric casting is performed on the operands to the inner multiply, promoting
387- them to the same data type as the accumulator/output.
388- """
389- domain (D .m , D .n , D .k )
390- implements (ContractionOpInterface )
391- C [D .m , D .n ] += cast (U , A [D .k , D .m ]) * cast (U , B [D .k , D .n ])
392-
393-
394- @linalg_structured_op
395- def matmul_transpose_b (
396- A = TensorDef (T1 , S .M , S .K ),
397- B = TensorDef (T2 , S .N , S .K ),
398- C = TensorDef (U , S .M , S .N , output = True ),
399- cast = TypeFnAttrDef (default = TypeFn .cast_signed ),
400- ):
401- """Performs a matrix multiplication of two 2D inputs with rhs operand
402- transposed.
403-
404- Numeric casting is performed on the operands to the inner multiply, promoting
405- them to the same data type as the accumulator/output.
406- """
407- domain (D .m , D .n , D .k )
408- implements (ContractionOpInterface )
409- C [D .m , D .n ] += cast (U , A [D .m , D .k ]) * cast (U , B [D .n , D .k ])
410-
411-
412376@linalg_structured_op
413377def mmt4d (
414378 lhs = TensorDef (TV .LhsType , S .M , S .K , S .M0 , S .K0 ),
@@ -453,44 +417,6 @@ def batch_mmt4d(
453417 ) * TypeFn .cast_signed (TV .AccumType , rhs [D .b , D .n , D .k , D .n0 , D .k0 ])
454418
455419
456- @linalg_structured_op
457- def batch_matmul_transpose_a (
458- A = TensorDef (T1 , Batch , S .K , S .M ),
459- B = TensorDef (T2 , Batch , S .K , S .N ),
460- C = TensorDef (U , Batch , S .M , S .N , output = True ),
461- ):
462- """Performs a batched matrix multiplication of two 3D inputs where lhs operand
463- has its non-batch dimensions transposed.
464-
465- Numeric casting is performed on the operands to the inner multiply, promoting
466- them to the same data type as the accumulator/output.
467- """
468- domain (D .b , D .m , D .n , D .k )
469- implements (ContractionOpInterface )
470- C [D .b , D .m , D .n ] += TypeFn .cast_signed (U , A [D .b , D .k , D .m ]) * TypeFn .cast_signed (
471- U , B [D .b , D .k , D .n ]
472- )
473-
474-
475- @linalg_structured_op
476- def batch_matmul_transpose_b (
477- A = TensorDef (T1 , Batch , S .M , S .K ),
478- B = TensorDef (T2 , Batch , S .N , S .K ),
479- C = TensorDef (U , Batch , S .M , S .N , output = True ),
480- ):
481- """Performs a batched matrix multiplication of two 3D inputs where rhs operand
482- has its non-batch dimensions transposed.
483-
484- Numeric casting is performed on the operands to the inner multiply, promoting
485- them to the same data type as the accumulator/output.
486- """
487- domain (D .b , D .m , D .n , D .k )
488- implements (ContractionOpInterface )
489- C [D .b , D .m , D .n ] += TypeFn .cast_signed (U , A [D .b , D .m , D .k ]) * TypeFn .cast_signed (
490- U , B [D .b , D .n , D .k ]
491- )
492-
493-
494420@linalg_structured_op
495421def quantized_batch_matmul (
496422 A = TensorDef (T1 , Batch , S .M , S .K ),
@@ -512,25 +438,6 @@ def quantized_batch_matmul(
512438 ) * (TypeFn .cast_signed (U , B [D .b , D .k , D .n ]) - TypeFn .cast_signed (U , BZp ))
513439
514440
515- @linalg_structured_op
516- def batch_reduce_matmul (
517- A = TensorDef (T1 , Batch , S .M , S .K ),
518- B = TensorDef (T2 , Batch , S .K , S .N ),
519- C = TensorDef (U , S .M , S .N , output = True ),
520- ):
521- """Performs a batch-reduce matrix multiplication of two 3D inputs.
522- The partial multiplication results are reduced into a 2D output.
523-
524- Numeric casting is performed on the operands to the inner multiply, promoting
525- them to the same data type as the accumulator/output.
526- """
527- domain (D .b , D .m , D .n , D .k )
528- implements (ContractionOpInterface )
529- C [D .m , D .n ] += TypeFn .cast_signed (U , A [D .b , D .m , D .k ]) * TypeFn .cast_signed (
530- U , B [D .b , D .k , D .n ]
531- )
532-
533-
534441@linalg_structured_op
535442def matvec (
536443 A = TensorDef (T1 , S .M , S .N ), y = TensorDef (T2 , S .N ), x = TensorDef (U , S .M , output = True )
0 commit comments