@@ -1459,15 +1459,31 @@ def getBroadcastDimensionsWithBatch : GlobalExpr</*needsprimal*/0, /*needsshadow
 
 def BroadcastDimsToReductionDims : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{
   SmallVector<int64_t> reduceDims;
-  auto outRank = cast<RankedTensorType>(op.getType()).getRank();
-  for (int64_t i = 0; i < outRank; i++) {
-    if (!llvm::is_contained(op.getBroadcastDimensions(), i)) {
-      reduceDims.push_back(i);
+  auto outTy = cast<RankedTensorType>(op.getType());
+  auto bcastDims = op.getBroadcastDimensions();
+  auto inTy = cast<RankedTensorType>(op.getOperand().getType());
+
+  for (auto en : llvm::enumerate(outTy.getShape())) {
+    ssize_t bcastIdx = -1;
+    for (auto en2 : llvm::enumerate(bcastDims)) {
+      if (en2.value() == en.index()) {
+        bcastIdx = en2.index();
+        break;
+      }
     }
+    if (bcastIdx != -1) {
+      if (en.value() != inTy.getShape()[bcastIdx]) {
+        reduceDims.push_back(en.index());
+        assert(inTy.getShape()[bcastIdx] == 1);
+      }
+      continue;
+    }
+    reduceDims.push_back(en.index());
   }
+
   if (gutils->width > 1) {
-    for (int64_t i = 0; i < reduceDims.size(); i++) {
-      reduceDims[i] += 1;
+    for (int i = 0; i < reduceDims.size(); i++) {
+      reduceDims[i]++;
     }
   }
   getI64Attr(builder, reduceDims);
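
To make the new logic easier to follow, here is a standalone sketch of the rule it implements; this is not part of the patch, the function name is hypothetical, and plain C++ containers stand in for the MLIR types. An output dimension ends up in reduceDims either when it does not appear in broadcast_dimensions (the broadcast created it) or when it does appear but the mapped input extent is 1 while the output extent differs (the broadcast stretched a unit dimension).

#include <cassert>
#include <cstdint>
#include <vector>

// Illustrative re-statement of BroadcastDimsToReductionDims
// (hypothetical name, no MLIR dependencies).
std::vector<int64_t> computeReduceDims(const std::vector<int64_t> &inShape,
                                       const std::vector<int64_t> &outShape,
                                       const std::vector<int64_t> &bcastDims) {
  std::vector<int64_t> reduceDims;
  for (int64_t outDim = 0; outDim < (int64_t)outShape.size(); ++outDim) {
    // Find the input dimension mapped onto this output dimension, if any.
    int64_t bcastIdx = -1;
    for (int64_t j = 0; j < (int64_t)bcastDims.size(); ++j) {
      if (bcastDims[j] == outDim) {
        bcastIdx = j;
        break;
      }
    }
    if (bcastIdx != -1) {
      // Mapped dimension: reduced only if it was stretched from extent 1.
      if (outShape[outDim] != inShape[bcastIdx]) {
        assert(inShape[bcastIdx] == 1);
        reduceDims.push_back(outDim);
      }
      continue;
    }
    // Unmapped dimension: created by the broadcast, always reduced.
    reduceDims.push_back(outDim);
  }
  return reduceDims;
}

// Example: inShape = {3, 1}, outShape = {2, 3, 4}, bcastDims = {1, 2}
// gives reduceDims = {0, 2}: dim 0 is new and dim 2 was stretched from 1 to 4.
// The previous code only collected dim 0, which is presumably what this
// change addresses.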
@@ -1503,11 +1519,33 @@ def BroadcastDimensionsToInversePermutation : GlobalExpr</*needsprimal*/0, /*nee
 }]>;
 
 def InsertDeletedReduceDimsType : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{
+  SmallVector<int64_t> reduceDims;
   auto outTy = cast<RankedTensorType>(op.getType());
+  auto bcastDims = op.getBroadcastDimensions();
+  auto inTy = cast<RankedTensorType>(op.getOperand().getType());
   auto outShape = outTy.getShape();
+
+  for (auto en : llvm::enumerate(outTy.getShape())) {
+    ssize_t bcastIdx = -1;
+    for (auto en2 : llvm::enumerate(bcastDims)) {
+      if (en2.value() == en.index()) {
+        bcastIdx = en2.index();
+        break;
+      }
+    }
+    if (bcastIdx != -1) {
+      if (en.value() != inTy.getShape()[bcastIdx]) {
+        reduceDims.push_back(en.index());
+        assert(inTy.getShape()[bcastIdx] == 1);
+      }
+      continue;
+    }
+    reduceDims.push_back(en.index());
+  }
+
   SmallVector<int64_t> reshapeShape(outTy.getRank(), -1);
   for (auto [i, sz] : llvm::enumerate(outShape)) {
-    if (!llvm::is_contained(op.getBroadcastDimensions(), i)) {
+    if (llvm::is_contained(reduceDims, i)) {
       reshapeShape[i] = 1;
     } else {
       reshapeShape[i] = sz;
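
The same reduceDims computation is duplicated in InsertDeletedReduceDimsType so the reshape that follows the reduction stays consistent: every reduced dimension is reinstated with extent 1, and every other dimension keeps its output extent. A minimal sketch of that shape computation, again with a hypothetical name and plain containers rather than the MLIR types:

#include <algorithm>
#include <cstdint>
#include <vector>

// Illustrative re-statement of the reshape-shape logic in
// InsertDeletedReduceDimsType (hypothetical name, no MLIR dependencies).
std::vector<int64_t> unitDimReshapeShape(const std::vector<int64_t> &outShape,
                                         const std::vector<int64_t> &reduceDims) {
  std::vector<int64_t> reshapeShape(outShape.size(), -1);
  for (int64_t i = 0; i < (int64_t)outShape.size(); ++i) {
    bool reduced =
        std::find(reduceDims.begin(), reduceDims.end(), i) != reduceDims.end();
    // Reduced dimensions come back as extent 1; the rest keep the output extent.
    reshapeShape[i] = reduced ? 1 : outShape[i];
  }
  return reshapeShape;
}

// Continuing the example above: outShape = {2, 3, 4} and reduceDims = {0, 2}
// give reshapeShape = {1, 3, 1}, which lines up with the operand's extents
// along the broadcast dimensions so the summed gradient can be mapped back to
// the operand shape {3, 1}.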
@@ -1559,12 +1597,5 @@ def : HLODerivative<"BroadcastInDimOp", (Op $x),
       )
     ],
     (
-      SelectIfActive $x,
-      (
-        BroadcastInDim
-        (ResultTypeWithBatch),
-        (Shadow $x),
-        (getBroadcastDimensionsWithBatch)
-      ),
-      (HLOConstantFP<"0">)
+      BroadcastInDim (ResultTypeWithBatch), (Shadow $x), (getBroadcastDimensionsWithBatch)
     )>;
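
For the shadow rule in the last hunk, the emitted derivative is the same broadcast applied to the operand's shadow, using the batch-adjusted result type and broadcast dimensions. Since broadcasting is linear, the tangent of the result is exactly the broadcast of the tangent; the toy check below illustrates that property with a hypothetical row-wise broadcast standing in for stablehlo.broadcast_in_dim, and is not part of the patch.

#include <cassert>
#include <cmath>
#include <vector>

// Toy stand-in for broadcast_in_dim: repeat a length-n vector as m rows
// (an m x n result with broadcast_dimensions = [1]). Hypothetical helper.
static std::vector<double> broadcastRows(const std::vector<double> &v, int m) {
  std::vector<double> out;
  for (int i = 0; i < m; ++i)
    out.insert(out.end(), v.begin(), v.end());
  return out;
}

int main() {
  std::vector<double> x = {1.0, 2.0, 3.0}, dx = {0.5, -1.0, 2.0};
  const double eps = 1e-3;
  // Linearity: broadcast(x + eps*dx) == broadcast(x) + eps*broadcast(dx),
  // i.e. the pushforward of a broadcast is the broadcast of the tangent.
  std::vector<double> xp(x.size());
  for (size_t i = 0; i < x.size(); ++i)
    xp[i] = x[i] + eps * dx[i];
  auto lhs = broadcastRows(xp, 4);
  auto bx = broadcastRows(x, 4);
  auto bdx = broadcastRows(dx, 4);
  for (size_t i = 0; i < lhs.size(); ++i)
    assert(std::fabs(lhs[i] - (bx[i] + eps * bdx[i])) < 1e-12);
  return 0;
}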