@@ -173,17 +173,18 @@ isValidGatherScatterBufferParams(Type offsetsTy, Type maskTy,
173173 return success ();
174174}
175175
176- LogicalResult IsValidStoreMatrixParams (
177- VectorType dataTy, MemDescType mdescTy, UnitAttr subgroup_block_io ,
178- MatrixAccessDirectionAttr vecDirection, IntegerAttr vecLength ,
179- function_ref<InFlightDiagnostic()> emitError) {
180-
181- if (!dataTy)
182- if (subgroup_block_io || vecDirection || vecLength )
183- return emitError () << " vec_length, vec_direction and subgroup_block_io "
176+ LogicalResult
177+ IsValidStoreMatrixParams ( VectorType dataTy, MemDescType mdescTy,
178+ UnitAttr subgroup_block_io ,
179+ function_ref<InFlightDiagnostic()> emitError) {
180+
181+ if (!dataTy) {
182+ if (subgroup_block_io)
183+ return emitError () << " subgroup_block_io "
184184 " are only allowed when result is a 1D VectorType." ;
185185 else
186186 return success ();
187+ }
187188
188189 if (mdescTy.getRank () != 2 )
189190 return emitError () << " mem_desc must be 2D." ;
@@ -192,8 +193,8 @@ LogicalResult IsValidStoreMatrixParams(
192193 ArrayRef<int64_t > mdescShape = mdescTy.getShape ();
193194
194195 if (dataShape.size () == 2 ) {
195- if (subgroup_block_io || vecDirection || vecLength )
196- return emitError () << " vec_length, vec_direction and subgroup_block_io "
196+ if (subgroup_block_io)
197+ return emitError () << " subgroup_block_io "
197198 " are only allowed when result is a 1D VectorType." ;
198199 if (llvm::any_of (llvm::zip_equal (dataShape, mdescShape),
199200 [](auto p) { return std::get<0 >(p) > std::get<1 >(p); }))
@@ -1097,20 +1098,16 @@ void LoadMatrixOp::build(OpBuilder &builder, OperationState &state, Type res,
10971098 // Call the generated builder with all parameters (including optional ones as
10981099 // nullptr/empty)
10991100 build (builder, state, res, memDesc, dynamicOffsets, staticOffsetsAttr,
1100- /* vec_length=*/ nullptr , /* vec_direction=*/ nullptr ,
11011101 /* subgroup_block_io=*/ nullptr , layout);
11021102}
11031103
11041104LogicalResult LoadMatrixOp::verify () {
11051105
11061106 auto resTy = dyn_cast<VectorType>(getRes ().getType ());
11071107 UnitAttr subgroup_block_io = getSubgroupBlockIoAttr ();
1108- MatrixAccessDirectionAttr vecDirection = getVecDirectionAttr ();
1109- IntegerAttr vecLength = getVecLengthAttr ();
11101108 MemDescType mdescTy = getMemDesc ().getType ();
11111109
11121110 return IsValidStoreMatrixParams (resTy, mdescTy, subgroup_block_io,
1113- vecDirection, vecLength,
11141111 [&]() { return emitError (); });
11151112}
11161113
@@ -1126,19 +1123,15 @@ void StoreMatrixOp::build(OpBuilder &builder, OperationState &state, Value data,
11261123 dispatchIndexOpFoldResults (offsets, dynamicOffsets, staticOffsets);
11271124 auto staticOffsetsAttr = builder.getDenseI64ArrayAttr (staticOffsets);
11281125 build (builder, state, data, memDesc, dynamicOffsets, staticOffsetsAttr,
1129- /* vec_length=*/ nullptr , /* vec_direction=*/ nullptr ,
11301126 /* subgroup_block_io=*/ nullptr , layout);
11311127}
11321128
11331129LogicalResult StoreMatrixOp::verify () {
11341130
11351131 auto dataTy = dyn_cast<VectorType>(getData ().getType ());
11361132 UnitAttr subgroup_block_io = getSubgroupBlockIoAttr ();
1137- MatrixAccessDirectionAttr vecDirection = getVecDirectionAttr ();
1138- IntegerAttr vecLength = getVecLengthAttr ();
11391133 MemDescType mdescTy = getMemDesc ().getType ();
11401134 return IsValidStoreMatrixParams (dataTy, mdescTy, subgroup_block_io,
1141- vecDirection, vecLength,
11421135 [&]() { return emitError (); });
11431136}
11441137
0 commit comments