@@ -72,15 +72,23 @@ static Value forceCastValueToType(OpBuilder &builder, Location loc, Value val,
7272 if (valTy == type)
7373 return val;
7474 auto srcVecTy = dyn_cast<VectorType>(valTy);
75+ auto dstVecTy = dyn_cast<VectorType>(type);
76+
7577 if (srcVecTy) {
76- auto dstVecTy = dyn_cast<VectorType>(type);
7778 assert (dstVecTy && " vector values cannot be forced into a non-vector type" );
78- assert (srcVecTy.getRank () == 1 && dstVecTy.getRank () == 1 &&
79- " only flat 1D vectors can be force casted" );
79+
80+ // Flatten source vector if it's not rank-1
81+ auto flatSrcVecTy = getFlattenedVectorType (srcVecTy);
82+ if (srcVecTy != flatSrcVecTy)
83+ val = builder.create <vector::ShapeCastOp>(loc, flatSrcVecTy, val);
84+
85+ // Flatten destination type if it's not rank-1
86+ auto flatDstVecTy = getFlattenedVectorType (dstVecTy);
87+
8088 int64_t dstVecLength =
81- dstVecTy .getElementTypeBitWidth () * dstVecTy .getShape ()[0 ];
89+ flatDstVecTy .getElementTypeBitWidth () * flatDstVecTy .getShape ()[0 ];
8290 int64_t srcVecLength =
83- srcVecTy .getElementTypeBitWidth () * srcVecTy .getShape ()[0 ];
91+ flatSrcVecTy .getElementTypeBitWidth () * flatSrcVecTy .getShape ()[0 ];
8492 if (srcVecLength != dstVecLength) {
8593 assert (srcVecLength < dstVecLength &&
8694 " only widening forced casts are supported" );
@@ -92,7 +100,19 @@ static Value forceCastValueToType(OpBuilder &builder, Location loc, Value val,
92100 else
93101 val = widen256bVectorValueTo512b (builder, loc, val);
94102 }
103+
104+ // Bitcast to flat destination type (bitcast only supports flat vectors)
105+ val = bitcastValueToType (builder, loc, val, flatDstVecTy);
106+
107+ // Reshape back to original destination shape if needed
108+ if (flatDstVecTy != dstVecTy)
109+ val = builder.create <vector::ShapeCastOp>(loc, dstVecTy, val);
110+
111+ return val;
95112 }
113+
114+ // Non-vector types can be bitcast directly
115+ assert (!dstVecTy && " cannot force cast scalar to vector type" );
96116 return bitcastValueToType (builder, loc, val, type);
97117}
98118
@@ -280,9 +300,10 @@ class AddElemOpConversion
280300 return failure ();
281301 }
282302
283- // create bitcast for result
284- rewriter.replaceOpWithNewOp <LLVM::BitcastOp>(op, op.getResult ().getType (),
285- addElemOp);
303+ // create bitcast/shape_cast for result
304+ auto resultVal = forceCastValueToType (rewriter, loc, addElemOp,
305+ op.getResult ().getType ());
306+ rewriter.replaceOp (op, resultVal);
286307 return success ();
287308 }
288309};
@@ -643,9 +664,10 @@ class MulElemOpConversion
643664 /* variant=*/ 2 , /* zero_acc=*/ 0 , /* shift16=*/ 1 ,
644665 /* sub_mul=*/ 0 , /* sub_acc1=*/ 0 , /* sub_acc2=*/ 0 , /* sub_mask=*/ 0 ));
645666
646- // create bitcast for result
647- rewriter.replaceOpWithNewOp <LLVM::BitcastOp>(op, op.getResult ().getType (),
648- acc64Val);
667+ // create bitcast/shape_cast for result
668+ auto resultVal =
669+ forceCastValueToType (rewriter, loc, acc64Val, op.getResult ().getType ());
670+ rewriter.replaceOp (op, resultVal);
649671 return success ();
650672 }
651673
@@ -828,9 +850,10 @@ class MulElemOpConversion
828850 createMacOps (c, e, cfMul))))))));
829851 }
830852
831- // create bitcast for result
832- rewriter.replaceOpWithNewOp <LLVM::BitcastOp>(op, op.getResult ().getType (),
833- finalMacVal);
853+ // create bitcast/shape_cast for result
854+ auto resultVal = forceCastValueToType (rewriter, loc, finalMacVal,
855+ op.getResult ().getType ());
856+ rewriter.replaceOp (op, resultVal);
834857 return success ();
835858 }
836859
@@ -881,9 +904,10 @@ class MulElemOpConversion
881904 rewriter.getI32Type ()}));
882905 }
883906
884- // create bitcast for result
885- rewriter.replaceOpWithNewOp <LLVM::BitcastOp>(op, op.getResult ().getType (),
886- mulElemOp);
907+ // create bitcast/shape_cast for result
908+ auto resultVal = forceCastValueToType (rewriter, loc, mulElemOp,
909+ op.getResult ().getType ());
910+ rewriter.replaceOp (op, resultVal);
887911 return success ();
888912 }
889913};
@@ -1186,13 +1210,10 @@ class SRSOpConversion : public mlir::ConvertOpToLLVMPattern<aievec::SRSOp> {
11861210 return failure ();
11871211 }
11881212
1189- // create bitcast for result if needed
1190- if (op.getResult ().getType () != srsIntrOp.getType ()) {
1191- rewriter.replaceOpWithNewOp <LLVM::BitcastOp>(op, op.getResult ().getType (),
1192- srsIntrOp);
1193- } else {
1194- rewriter.replaceOp (op, srsIntrOp);
1195- }
1213+ // create bitcast/shape_cast for result if needed
1214+ auto resultVal = forceCastValueToType (rewriter, loc, srsIntrOp,
1215+ op.getResult ().getType ());
1216+ rewriter.replaceOp (op, resultVal);
11961217
11971218 return success ();
11981219 }
@@ -1388,9 +1409,10 @@ class ConcatOpConversion
13881409 return failure ();
13891410 }
13901411
1391- // create bitcast for result
1392- rewriter.replaceOpWithNewOp <LLVM::BitcastOp>(op, op.getResult ().getType (),
1393- concatOp);
1412+ // create bitcast/shape_cast for result
1413+ auto resultVal =
1414+ forceCastValueToType (rewriter, loc, concatOp, op.getResult ().getType ());
1415+ rewriter.replaceOp (op, resultVal);
13941416
13951417 return success ();
13961418 }
@@ -1484,13 +1506,10 @@ class ExtOpConversion : public mlir::ConvertOpToLLVMPattern<aievec::ExtOp> {
14841506 return failure ();
14851507 }
14861508
1487- // create bitcast for result
1488- if (op.getResult ().getType () != extOp.getType ()) {
1489- rewriter.replaceOpWithNewOp <LLVM::BitcastOp>(op, op.getResult ().getType (),
1490- extOp);
1491- } else {
1492- rewriter.replaceOp (op, extOp);
1493- }
1509+ // create bitcast/shape_cast for result
1510+ auto resultVal =
1511+ forceCastValueToType (rewriter, loc, extOp, op.getResult ().getType ());
1512+ rewriter.replaceOp (op, resultVal);
14941513
14951514 return success ();
14961515 }
@@ -1964,9 +1983,10 @@ class ShiftOpConversion : public mlir::ConvertOpToLLVMPattern<aievec::ShiftOp> {
19641983 rewriter.getI32Type (), rewriter.getI32Type ()}));
19651984 }
19661985
1967- // create bitcast for result
1968- rewriter.replaceOpWithNewOp <LLVM::BitcastOp>(op, op.getResult ().getType (),
1969- shiftOp);
1986+ // create bitcast/shape_cast for result
1987+ auto resultVal =
1988+ forceCastValueToType (rewriter, loc, shiftOp, op.getResult ().getType ());
1989+ rewriter.replaceOp (op, resultVal);
19701990
19711991 return success ();
19721992 }
0 commit comments