@@ -173,6 +173,51 @@ isValidGatherScatterBufferParams(Type offsetsTy, Type maskTy,
173173 return success ();
174174}
175175
176+ LogicalResult IsValidStoreMatrixParams (
177+ VectorType dataTy, MemDescType mdescTy, UnitAttr subgroup_block_io,
178+ MatrixAccessDirectionAttr vecDirection, IntegerAttr vecLength,
179+ function_ref<InFlightDiagnostic()> emitError) {
180+
181+ if (!dataTy)
182+ if (subgroup_block_io || vecDirection || vecLength)
183+ return emitError () << " vec_length, vec_direction and subgroup_block_io "
184+ " are only allowed when result is a 1D VectorType." ;
185+ else
186+ return success ();
187+
188+ if (mdescTy.getRank () != 2 )
189+ return emitError () << " mem_desc must be 2D." ;
190+
191+ ArrayRef<int64_t > dataShape = dataTy.getShape ();
192+ ArrayRef<int64_t > mdescShape = mdescTy.getShape ();
193+
194+ if (dataShape.size () == 2 ) {
195+ if (subgroup_block_io || vecDirection || vecLength)
196+ return emitError () << " vec_length, vec_direction and subgroup_block_io "
197+ " are only allowed when result is a 1D VectorType." ;
198+ if (llvm::any_of (llvm::zip_equal (dataShape, mdescShape),
199+ [](auto p) { return std::get<0 >(p) > std::get<1 >(p); }))
200+ return emitError () << " data shape must not exceed mem_desc shape." ;
201+ } else if (dataShape.size () == 1 ) {
202+
203+ SmallVector<int64_t > blockSize = mdescTy.getBlockSize ();
204+ // if the subgroup_block_io attribute is set, mdescTy must have block
205+ // attribute
206+ if (subgroup_block_io && !blockSize.size ())
207+ return emitError () << " mem_desc must have block attribute when "
208+ " subgroup_block_io is set." ;
209+ // if the subgroup_block_io attribute is set, the memdesc should be row
210+ // major
211+ if (subgroup_block_io && mdescTy.isColMajor ())
212+ return emitError () << " mem_desc should be row major when "
213+ " subgroup_block_io is set." ;
214+ } else if (dataShape.size () == 0 ) {
215+ return emitError () << " result shape must not be empty." ;
216+ }
217+
218+ return success ();
219+ }
220+
176221// ===----------------------------------------------------------------------===//
177222// XeGPU_CreateNdDescOp
178223// ===----------------------------------------------------------------------===//
@@ -1053,25 +1098,20 @@ void LoadMatrixOp::build(OpBuilder &builder, OperationState &state, Type res,
10531098 // nullptr/empty)
10541099 build (builder, state, res, memDesc, dynamicOffsets, staticOffsetsAttr,
10551100 /* vec_length=*/ nullptr , /* vec_direction=*/ nullptr ,
1056- /* subgroupBlockIO =*/ nullptr , layout);
1101+ /* subgroup_block_io =*/ nullptr , layout);
10571102}
10581103
10591104LogicalResult LoadMatrixOp::verify () {
1060- VectorType resTy = getRes ().getType ();
1061- MemDescType mdescTy = getMemDesc ().getType ();
1062-
1063- if (mdescTy.getRank () != 2 )
1064- return emitOpError (" mem_desc must be 2D." );
10651105
1066- ArrayRef<int64_t > valueShape = resTy.getShape ();
1067- ArrayRef<int64_t > mdescShape = mdescTy.getShape ();
1106+ auto resTy = dyn_cast<VectorType>(getRes ().getType ());
1107+ UnitAttr subgroup_block_io = getSubgroupBlockIoAttr ();
1108+ MatrixAccessDirectionAttr vecDirection = getVecDirectionAttr ();
1109+ IntegerAttr vecLength = getVecLengthAttr ();
1110+ MemDescType mdescTy = getMemDesc ().getType ();
10681111
1069- if (valueShape.size () != 1 ) {
1070- if (llvm::any_of (llvm::zip_equal (valueShape, mdescShape),
1071- [](auto p) { return std::get<0 >(p) > std::get<1 >(p); }))
1072- return emitOpError (" result shape must not exceed mem_desc shape." );
1073- }
1074- return success ();
1112+ return IsValidStoreMatrixParams (resTy, mdescTy, subgroup_block_io,
1113+ vecDirection, vecLength,
1114+ [&]() { return emitError (); });
10751115}
10761116
10771117// ===----------------------------------------------------------------------===//
@@ -1086,24 +1126,20 @@ void StoreMatrixOp::build(OpBuilder &builder, OperationState &state, Value data,
10861126 dispatchIndexOpFoldResults (offsets, dynamicOffsets, staticOffsets);
10871127 auto staticOffsetsAttr = builder.getDenseI64ArrayAttr (staticOffsets);
10881128 build (builder, state, data, memDesc, dynamicOffsets, staticOffsetsAttr,
1089- layout);
1129+ /* vec_length=*/ nullptr , /* vec_direction=*/ nullptr ,
1130+ /* subgroup_block_io=*/ nullptr , layout);
10901131}
10911132
10921133LogicalResult StoreMatrixOp::verify () {
1093- VectorType dataTy = getData ().getType ();
1094- MemDescType mdescTy = getMemDesc ().getType ();
10951134
1096- if (mdescTy.getRank () != 2 )
1097- return emitOpError (" mem_desc must be 2D." );
1098-
1099- ArrayRef<int64_t > dataShape = dataTy.getShape ();
1100- ArrayRef<int64_t > mdescShape = mdescTy.getShape ();
1101- if (dataShape.size () != 1 ) {
1102- if (llvm::any_of (llvm::zip_equal (dataShape, mdescShape),
1103- [](auto p) { return std::get<0 >(p) > std::get<1 >(p); }))
1104- return emitOpError (" data shape must not exceed mem_desc shape." );
1105- }
1106- return success ();
1135+ auto dataTy = dyn_cast<VectorType>(getData ().getType ());
1136+ UnitAttr subgroup_block_io = getSubgroupBlockIoAttr ();
1137+ MatrixAccessDirectionAttr vecDirection = getVecDirectionAttr ();
1138+ IntegerAttr vecLength = getVecLengthAttr ();
1139+ MemDescType mdescTy = getMemDesc ().getType ();
1140+ return IsValidStoreMatrixParams (dataTy, mdescTy, subgroup_block_io,
1141+ vecDirection, vecLength,
1142+ [&]() { return emitError (); });
11071143}
11081144
11091145// ===----------------------------------------------------------------------===//
0 commit comments