@@ -977,8 +977,6 @@ def : HLODerivative<"Atan2Op", (Op $x, $y),
977977 (CheckedDiv (Sub (Mul $x, (Shadow $y)), (Mul $y, (Shadow $x))), (Add (Pow $x, (HLOConstantFP<"2">)), (Pow $y, (HLOConstantFP<"2"> $y))))
978978 >;
979979
980- def : HLOReadOnlyIdentityOp<"BroadcastInDimOp">;
981-
982980// Gather Adjoint
983981def ResultDimensionNumbers : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{
984982 op.getDimensionNumbers();
@@ -1446,3 +1444,127 @@ def : HLODerivative<"FftOp",
14461444 Fft (Shadow $x), (FftType), (FftLength)
14471445 )
14481446>;
1447+
// Broadcast dimensions to use when broadcasting the operand's shadow in
// forward mode. In batched (vector) mode the shadow carries a leading batch
// dimension of size `gutils->width`, so the batch dim (operand dim 0) must
// map to output dim 0 and every original broadcast dimension shifts right by
// one. Produces an I64 array attribute via getI64Attr.
def getBroadcastDimensionsWithBatch : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{
  auto bcastDims = op.getBroadcastDimensions();
  SmallVector<int64_t> newBcastDims;
  newBcastDims.reserve(bcastDims.size() + 1);
  for (auto dim : bcastDims) {
    newBcastDims.push_back(to_i64(dim));
  }
  if (gutils->width > 1) {
    // The batched result gains a leading batch dim: shift all mapped output
    // dims by one and map the shadow's batch dim to output dim 0. (Inserting
    // the raw `width` value as a dimension number would be wrong — it is a
    // size, not a dim index, and can exceed the batched output rank.)
    for (auto &dim : newBcastDims)
      dim += 1;
    newBcastDims.insert(newBcastDims.begin(), 0);
  }
  getI64Attr(builder, newBcastDims);
}]>;
1459+
// Output dimensions that broadcast_in_dim materialized (i.e. that are not
// named in broadcast_dimensions); the reverse pass sums the incoming gradient
// over exactly these. In batched mode the gradient has a leading batch dim,
// so every reduction dim shifts right by one (the batch dim is never
// reduced). Produces an I64 array attribute via getI64Attr.
def BroadcastDimsToReductionDims : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{
  SmallVector<int64_t> reduceDims;
  auto outRank = cast<RankedTensorType>(op.getType()).getRank();
  for (int64_t i = 0; i < outRank; i++) {
    if (!llvm::is_contained(op.getBroadcastDimensions(), i)) {
      reduceDims.push_back(i);
    }
  }
  if (gutils->width > 1) {
    // Range-for instead of the original index loop: avoids the signed/
    // unsigned comparison of int64_t against reduceDims.size().
    for (auto &dim : reduceDims)
      dim += 1;
  }
  getI64Attr(builder, reduceDims);
}]>;
1475+
// Scalar zero of the operand's element type, used as the init value of the
// gradient-summing ReduceOp. Built through the shadow type so it is
// compatible with batched (width > 1) mode.
def ReduceAddNullValue : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{
  auto elemTy =
      cast<RankedTensorType>(op->getOperand(0).getType()).getElementType();
  auto scalarTy = RankedTensorType::get({}, elemTy);
  auto shadowTy = gutils->getShadowType(scalarTy);
  cast<AutoDiffTypeInterface>(shadowTy).createNullValue(builder, op.getLoc());
}]>;
1483+
// Permutation that reorders the (reduced-and-reshaped) gradient so that the
// dims mapped from the operand come first, in operand order, with the
// materialized size-1 dims trailing; the outer reshape can then squeeze it to
// the operand type. In batched mode the transposed value carries a leading
// batch dim (see InsertDeletedReduceDimsType), so the permutation must keep
// dim 0 fixed and shift all other entries by one — otherwise its size would
// not match the operand rank of the TransposeOp.
def BroadcastDimensionsToInversePermutation : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{
  auto outTy = cast<RankedTensorType>(op.getType());
  SmallVector<int64_t> perm(outTy.getRank());
  // mapping[outputDim] = position the dim should occupy after the transpose.
  SmallVector<int64_t> mapping(outTy.getRank(), -1);
  for (auto [i, dim] : llvm::enumerate(op.getBroadcastDimensions())) {
    mapping[to_i64(dim)] = i;
  }

  // Materialized (non-mapped) output dims are sent to the trailing positions.
  int64_t next = op.getBroadcastDimensions().size();
  for (int64_t i = 0; i < outTy.getRank(); i++) {
    if (mapping[i] == -1) {
      mapping[i] = next++;
    }
  }

  // Invert: perm[target position] = source output dim.
  for (int64_t i = 0; i < outTy.getRank(); i++) {
    perm[mapping[i]] = i;
  }

  if (gutils->width > 1) {
    // Leading batch dim stays in place; everything else shifts right by one.
    for (auto &p : perm)
      p += 1;
    perm.insert(perm.begin(), 0);
  }
  getI64Attr(builder, perm);
}]>;
1504+
// Type of the post-reduction gradient re-expanded to the output's rank:
// dims that were summed away (not in broadcast_dimensions) reappear with
// extent 1, dims mapped from the operand keep their output extent, and a
// leading batch extent of `gutils->width` is prepended in batched mode.
def InsertDeletedReduceDimsType : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{
  auto outTy = cast<RankedTensorType>(op.getType());
  auto outShape = outTy.getShape();
  SmallVector<int64_t> expandedShape;
  expandedShape.reserve(outTy.getRank() + 1);
  if (gutils->width > 1) {
    expandedShape.push_back(gutils->width);
  }
  for (int64_t i = 0; i < outTy.getRank(); i++) {
    bool mapped = llvm::is_contained(op.getBroadcastDimensions(), i);
    expandedShape.push_back(mapped ? outShape[i] : 1);
  }
  RankedTensorType::get(expandedShape, outTy.getElementType());
}]>;
1522+
// ReduceOp wrapped so the emitted builder call becomes
// `createAddRegion(<operands...>)->getResult(0)`: the pre/post strings attach
// an add reduction body and unwrap the single result.
def ReduceAdd : HLOInst<"ReduceOp", ")->getResult(0)", "createAddRegion(">;
1524+
// The op's result type, widened with a leading batch dimension of size
// `gutils->width` when running in batched (vector) mode; unchanged otherwise.
def ResultTypeWithBatch : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{
  auto resTy = cast<RankedTensorType>(op.getType());
  auto resShape = resTy.getShape();
  SmallVector<int64_t> batchedShape;
  batchedShape.reserve(resShape.size() + 1);
  if (gutils->width > 1) {
    batchedShape.push_back(gutils->width);
  }
  batchedShape.append(resShape.begin(), resShape.end());
  RankedTensorType::get(batchedShape, resTy.getElementType());
}]>;
1534+
// The operand's type, widened with a leading batch dimension of size
// `gutils->width` when running in batched (vector) mode; unchanged otherwise.
def InputTypeWithBatch : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{
  auto operandTy = cast<RankedTensorType>(op.getOperand().getType());
  auto operandShape = operandTy.getShape();
  SmallVector<int64_t> batchedShape;
  batchedShape.reserve(operandShape.size() + 1);
  if (gutils->width > 1) {
    batchedShape.push_back(gutils->width);
  }
  batchedShape.append(operandShape.begin(), operandShape.end());
  RankedTensorType::get(batchedShape, operandTy.getElementType());
}]>;
1544+
// Derivative rules for BroadcastInDimOp.
//
// Reverse mode (first list, one entry for operand $x): the adjoint of a
// broadcast sums the incoming gradient over the materialized output dims and
// rearranges it back into the operand's layout:
//   1. ReduceAdd sums DiffeRet over the dims not named in
//      broadcast_dimensions (BroadcastDimsToReductionDims), starting from a
//      scalar zero (ReduceAddNullValue);
//   2. the inner Reshape reinserts the reduced dims with extent 1
//      (InsertDeletedReduceDimsType);
//   3. Transpose moves the operand-mapped dims to the front in operand order
//      (BroadcastDimensionsToInversePermutation);
//   4. the outer Reshape squeezes the result down to the operand type
//      (InputTypeWithBatch).
// Forward mode (trailing dag): the tangent of a broadcast is the same
// broadcast applied to $x's shadow (with batch-aware type/dims), or a zero
// constant when $x is inactive.
def : HLODerivative<"BroadcastInDimOp", (Op $x),
  [
    (
      Reshape
      (InputTypeWithBatch),
      (
        Transpose
        (
          Reshape
          (InsertDeletedReduceDimsType),
          (ReduceAdd (DiffeRet), (ReduceAddNullValue), (BroadcastDimsToReductionDims))
        ),
        (BroadcastDimensionsToInversePermutation)
      )
    )
  ],
  (
    SelectIfActive $x,
    (
      BroadcastInDim
      (ResultTypeWithBatch),
      (Shadow $x),
      (getBroadcastDimensionsWithBatch)
    ),
    (HLOConstantFP<"0">)
  )>;
0 commit comments