|
32 | 32 | #include "stablehlo/dialect/ChloOps.h" |
33 | 33 | #include "stablehlo/dialect/StablehloOps.h" |
34 | 34 |
|
35 | | -#include "src/enzyme_ad/jax/Dialect/Dialect.h" |
36 | | -#include "src/enzyme_ad/jax/Dialect/Ops.h" |
37 | 35 | #include "src/enzyme_ad/jax/Implementations/WhileLoopInfo.h" |
38 | 36 | #include "src/enzyme_ad/jax/Implementations/XLADerivatives.h" |
39 | 37 | #include "src/enzyme_ad/jax/Utils.h" |
@@ -1694,9 +1692,8 @@ struct SHLOGetDimensionSizeOpBatchInterface |
1694 | 1692 | auto bcastOp = BroadcastInDimOp::create( |
1695 | 1693 | builder, src->getLoc(), |
1696 | 1694 | RankedTensorType::get( |
1697 | | - batchSizes, |
1698 | | - cast<RankedTensorType>(newOp->getResult(0).getType()) |
1699 | | - .getElementType()), |
| 1695 | + batchSizes, cast<RankedTensorType>(newOp->getResult(0).getType()) |
| 1696 | + .getElementType()), |
1700 | 1697 | newOp->getResult(0), builder.getDenseI64ArrayAttr({})); |
1701 | 1698 | mapper.map(src->getResult(0), bcastOp->getResult(0)); |
1702 | 1699 | return success(); |
@@ -3949,8 +3946,6 @@ void mlir::enzyme::registerStableHLODialectAutoDiffInterface( |
3949 | 3946 | *context); |
3950 | 3947 |
|
3951 | 3948 | ConstantOp::attachInterface<SHLOConstantOpBatchInterface>(*context); |
3952 | | - GetDimensionSizeOp::attachInterface<SHLOGetDimensionSizeOpBatchInterface>( |
3953 | | - *context); |
3954 | 3949 | TransposeOp::attachInterface<SHLOTransposeOpBatchInterface>(*context); |
3955 | 3950 | IfOp::attachInterface<SHLOGenericBatchOpInterface<IfOp>>(*context); |
3956 | 3951 | WhileOp::attachInterface<SHLOGenericBatchOpInterface<WhileOp>>(*context); |
@@ -3980,9 +3975,5 @@ void mlir::enzyme::registerStableHLODialectAutoDiffInterface( |
3980 | 3975 |
|
3981 | 3976 | AddOp::attachInterface<StablehloAddSimplifyMathInterface>(*context); |
3982 | 3977 | SubtractOp::attachInterface<StablehloSubSimplifyMathInterface>(*context); |
3983 | | - |
3984 | | - // TODO: move into its own file |
3985 | | - enzymexla::JITCallOp::attachInterface< |
3986 | | - SHLOGenericBatchOpInterface<enzymexla::JITCallOp>>(*context); |
3987 | 3978 | }); |
3988 | 3979 | } |
0 commit comments