Skip to content

Commit 7026eb7

Browse files
authored
Merge branch 'main' into wfelix_xla_dev
2 parents cb3b0c4 + 9faee26 commit 7026eb7

File tree

6 files changed

+73
-16
lines changed

6 files changed

+73
-16
lines changed

.github/workflows/test-gb-25.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ jobs:
5252
- 'mg/sharded-factors'
5353
# - '0123456789abcdef0123456789abcdef01234567'
5454
reactant_commit:
55-
- 'main'
55+
- 'ap/persistent_compile_cache'
5656
# - "ap/updated_no_nan"
5757

5858
steps:
@@ -189,12 +189,12 @@ jobs:
189189
timeout-minutes: 60
190190
run: |
191191
export XLA_FLAGS='--xla_dump_to=${{ env.GB25_DIR }}/xla_dump'
192-
timeout --signal=TERM --verbose 59m mpiexecjl -np 1 julia --color=yes --project -O0 sharding/sharded_baroclinic_instability_simulation_run.jl
192+
timeout --signal=TERM --verbose 59m mpiexecjl -np 1 julia --color=yes --project -O0 --startup-file=no --threads=16 --compiled-modules=strict sharding/sharded_baroclinic_instability_simulation_run.jl
193193
working-directory: ${{ env.GB25_DIR }}
194194
- name: Test correctness in GB-25 code
195195
timeout-minutes: 20
196196
run: |
197-
timeout --signal=TERM --verbose 19m mpiexecjl -np 1 julia --color=yes --project -O0 correctness/correctness_sharded_baroclinic_instability_simulation_run.jl
197+
timeout --signal=TERM --verbose 19m mpiexecjl -np 1 julia --color=yes --project -O0 --startup-file=no --threads=16 --compiled-modules=strict correctness/correctness_sharded_baroclinic_instability_simulation_run.jl
198198
working-directory: ${{ env.GB25_DIR }}
199199
- name: Upload MLIR and XLA modules
200200
uses: actions/upload-artifact@v4

src/enzyme_ad/jax/Passes/AffineCFG.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2344,7 +2344,7 @@ struct MoveSelectToAffine : public OpRewritePattern<arith::SelectOp> {
23442344

23452345
bool changed = false;
23462346
auto condOp = ifOp.getCondition().getDefiningOp();
2347-
if (isa<AndIOp, OrIOp>(condOp)) {
2347+
if (condOp && isa<AndIOp, OrIOp>(condOp)) {
23482348
// condition, Negated
23492349

23502350
for (auto &opv : condOp->getOpOperands()) {

src/enzyme_ad/jax/Passes/AffineToStableHLORaising.cpp

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
#include "mlir/IR/IRMapping.h"
3131
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
3232

33+
#include "Interfaces/AutoDiffTypeInterface.h"
3334
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
3435

3536
#include "src/enzyme_ad/jax/Dialect/Ops.h"
@@ -2529,6 +2530,7 @@ struct AffineToStableHLORaisingPass
25292530
for (auto arg : operands0) {
25302531

25312532
Attribute attr;
2533+
25322534
if (matchPattern(arg, m_Constant(&attr))) {
25332535
affine::AffineValueMap accessMap(AffineMap::get(arg.getContext()),
25342536
{});
@@ -2539,15 +2541,21 @@ struct AffineToStableHLORaisingPass
25392541
auto unrankedTensorType = RankedTensorType::get({}, ET);
25402542
OpBuilder builder(arg.getContext());
25412543
builder.setInsertionPointToEnd(newBlock);
2542-
auto newConst = builder.create<stablehlo::ConstantOp>(
2543-
arg.getLoc(), unrankedTensorType,
2544-
SplatElementsAttr::get(
2545-
unrankedTensorType,
2546-
ArrayRef<Attribute>(
2547-
isIndex ? IntegerAttr::get(
2548-
ET, cast<IntegerAttr>(attr).getValue())
2549-
: attr)));
2550-
auto newVal = newConst.getResult();
2544+
Value newVal;
2545+
if (arg.getDefiningOp<ub::PoisonOp>()) {
2546+
newVal = cast<mlir::enzyme::AutoDiffTypeInterface>(arg.getType())
2547+
.createNullValue(builder, arg.getLoc());
2548+
} else {
2549+
auto newConst = builder.create<stablehlo::ConstantOp>(
2550+
arg.getLoc(), unrankedTensorType,
2551+
SplatElementsAttr::get(
2552+
unrankedTensorType,
2553+
ArrayRef<Attribute>(
2554+
isIndex ? IntegerAttr::get(
2555+
ET, cast<IntegerAttr>(attr).getValue())
2556+
: attr)));
2557+
newVal = newConst.getResult();
2558+
}
25512559
mapping.map(arg, newVal);
25522560
maps[newVal] = accessMap;
25532561
continue;

src/enzyme_ad/jax/Passes/LowerJIT.cpp

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,22 @@
8080

8181
#include "mlir/Target/LLVMIR/Export.h"
8282

83+
#if (defined(_WIN32) || defined(__CYGWIN__)) && \
84+
!defined(MLIR_CAPI_ENABLE_WINDOWS_DLL_DECLSPEC)
85+
// Visibility annotations disabled.
86+
#define MLIR_CAPI_EXPORTED
87+
#elif defined(_WIN32) || defined(__CYGWIN__)
88+
// Windows visibility declarations.
89+
#if MLIR_CAPI_BUILDING_LIBRARY
90+
#define MLIR_CAPI_EXPORTED __declspec(dllexport)
91+
#else
92+
#define MLIR_CAPI_EXPORTED __declspec(dllimport)
93+
#endif
94+
#else
95+
// Non-windows: use visibility attributes.
96+
#define MLIR_CAPI_EXPORTED __attribute__((visibility("default")))
97+
#endif
98+
8399
#define DEBUG_TYPE "lower-jit"
84100

85101
namespace mlir {
@@ -274,7 +290,8 @@ bool initJIT() {
274290
return true;
275291
}
276292

277-
extern "C" void EnzymeJaXMapSymbol(const char *name, void *symbol) {
293+
extern "C" MLIR_CAPI_EXPORTED void EnzymeJaXMapSymbol(const char *name,
294+
void *symbol) {
278295
initJIT();
279296
MappedSymbols[JIT->mangleAndIntern(name)] = llvm::orc::ExecutorSymbolDef(
280297
llvm::orc::ExecutorAddr::fromPtr(symbol), llvm::JITSymbolFlags());

src/enzyme_ad/jax/cpu.cc

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,22 @@
22
#include "xla/service/custom_call_target_registry.h"
33
#include <cstring>
44

5+
#if (defined(_WIN32) || defined(__CYGWIN__)) && \
6+
!defined(MLIR_CAPI_ENABLE_WINDOWS_DLL_DECLSPEC)
7+
// Visibility annotations disabled.
8+
#define MLIR_CAPI_EXPORTED
9+
#elif defined(_WIN32) || defined(__CYGWIN__)
10+
// Windows visibility declarations.
11+
#if MLIR_CAPI_BUILDING_LIBRARY
12+
#define MLIR_CAPI_EXPORTED __declspec(dllexport)
13+
#else
14+
#define MLIR_CAPI_EXPORTED __declspec(dllimport)
15+
#endif
16+
#else
17+
// Non-windows: use visibility attributes.
18+
#define MLIR_CAPI_EXPORTED __attribute__((visibility("default")))
19+
#endif
20+
521
template <bool withError> struct CallInfo;
622

723
template <> struct CallInfo<false> {
@@ -28,7 +44,7 @@ void forwarding_custom_call(void *out, const void **in, const void *opaque_ptr,
2844
}
2945
}
3046

31-
extern "C" void RegisterEnzymeXLACPUHandler() {
47+
extern "C" MLIR_CAPI_EXPORTED void RegisterEnzymeXLACPUHandler() {
3248
xla::CustomCallTargetRegistry::Global()->Register(
3349
"enzymexla_compile_cpu", (void *)&forwarding_custom_call<false>, "Host");
3450
xla::CustomCallTargetRegistry::Global()->Register(

src/enzyme_ad/jax/gpu.cc

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,22 @@
11
#include "xla/ffi/api/ffi.h"
22
#include "xla/ffi/ffi_api.h"
33

4+
#if (defined(_WIN32) || defined(__CYGWIN__)) && \
5+
!defined(MLIR_CAPI_ENABLE_WINDOWS_DLL_DECLSPEC)
6+
// Visibility annotations disabled.
7+
#define MLIR_CAPI_EXPORTED
8+
#elif defined(_WIN32) || defined(__CYGWIN__)
9+
// Windows visibility declarations.
10+
#if MLIR_CAPI_BUILDING_LIBRARY
11+
#define MLIR_CAPI_EXPORTED __declspec(dllexport)
12+
#else
13+
#define MLIR_CAPI_EXPORTED __declspec(dllimport)
14+
#endif
15+
#else
16+
// Non-windows: use visibility attributes.
17+
#define MLIR_CAPI_EXPORTED __attribute__((visibility("default")))
18+
#endif
19+
420
template <bool withError> struct CallInfo;
521

622
template <> struct CallInfo<false> {
@@ -138,7 +154,7 @@ XLA_FFI_Error *execute(XLA_FFI_CallFrame *call_frame) {
138154
return nullptr;
139155
}
140156

141-
extern "C" void RegisterEnzymeXLAGPUHandler() {
157+
extern "C" MLIR_CAPI_EXPORTED void RegisterEnzymeXLAGPUHandler() {
142158
XLA_FFI_Handler_Bundle bundle = {instantiate, prepare, initialize<false>,
143159
execute<false>};
144160

0 commit comments

Comments
 (0)