@@ -931,14 +931,23 @@ struct WgToSgVectorStepOp : public OpConversionPattern<vector::StepOp> {
931931 return failure ();
932932
933933 VectorType newTy = type.cloneWith (*sgShape, type.getElementType ());
934- Value steps = vector::StepOp::create (rewriter, loc, newTy);
934+ auto steps = vector::StepOp::create (rewriter, loc, newTy);
935935 SmallVector<Value> newOps;
936936 for (auto offsets : *sgOffsets) {
937937 // Broadcast the offset scalar to a vector & add to the base steps
938- Value bcastOffset =
938+ auto bcastOffset =
939939 vector::BroadcastOp::create (rewriter, loc, newTy, offsets[0 ]);
940- Value finalSteps =
940+ auto finalSteps =
941941 arith::AddIOp::create (rewriter, loc, steps, bcastOffset);
942+ if (!layout.getLaneLayoutAsInt ().empty () ||
943+ !layout.getLaneDataAsInt ().empty ()) {
944+ xegpu::setDistributeLayoutAttr (steps->getResult (0 ),
945+ layout.dropSgLayoutAndData ());
946+ xegpu::setDistributeLayoutAttr (bcastOffset->getResult (0 ),
947+ layout.dropSgLayoutAndData ());
948+ xegpu::setDistributeLayoutAttr (finalSteps->getResult (0 ),
949+ layout.dropSgLayoutAndData ());
950+ }
942951 newOps.push_back (finalSteps);
943952 }
944953
@@ -971,14 +980,12 @@ struct WgToSgVectorShapeCastOp
971980 VectorType::get (sgShape, resultType.getElementType ());
972981
973982 // TODO: Add check for compatible layouts in layout attr.
974- // Only support ShapeCast which expands or reduces unit dims only.
975- // That is, only allow shape casts where the non-unit dimensions are
976- // preserved, and any added or removed dimensions must be of size 1.
977983 auto srcType = dyn_cast<VectorType>(adaptor.getSource ()[0 ].getType ());
978984 if (!srcType)
979985 return failure ();
980986
981- auto isUnitOrPreserved = [](ArrayRef<int64_t > src, ArrayRef<int64_t > dst) {
987+ // Check that shape_cast only adds/removes unit dimensions,
988+ auto onlyUnitDims = [](ArrayRef<int64_t > src, ArrayRef<int64_t > dst) {
982989 // Remove all 1s from both shapes and compare the rest.
983990 SmallVector<int64_t > srcNonUnit, dstNonUnit;
984991 for (int64_t d : src)
@@ -990,8 +997,8 @@ struct WgToSgVectorShapeCastOp
990997 return srcNonUnit == dstNonUnit;
991998 };
992999
993- if (!isUnitOrPreserved (srcType.getShape (), sgShape) ||
994- !isUnitOrPreserved (sgShape, srcType.getShape ()))
1000+ if (!onlyUnitDims (srcType.getShape (), sgShape) ||
1001+ !onlyUnitDims (sgShape, srcType.getShape ()))
9951002 return failure ();
9961003
9971004 SmallVector<Value> newShapeCastOps;
0 commit comments