Skip to content

Commit 78fd642

Browse files
committed
fix: neural gcm enable correctness check
1 parent 66133db commit 78fd642

File tree

3 files changed

+9
-5
lines changed

3 files changed

+9
-5
lines changed

src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
#include "llvm/ADT/SmallSet.h"
5454

5555
#include "llvm/ADT/MapVector.h"
56+
#include <cassert>
5657
#include <cstddef>
5758
#include <iterator>
5859
#include <numeric>
@@ -13045,6 +13046,12 @@ struct BroadcastInDimOpCanon final
1304513046
// Eliminate redundant nested BroadcastInDim.
1304613047
if (auto definingOp =
1304713048
operand.getDefiningOp<stablehlo::BroadcastInDimOp>()) {
13049+
DenseElementsAttr denseAttr;
13050+
if (matchPattern(definingOp.getOperand(), m_Constant(&denseAttr)) &&
13051+
!denseAttr.isSplat()) {
13052+
// TODO: investigate why this leads to incorrect results
13053+
return failure();
13054+
}
1304813055
auto newIndices = llvm::to_vector(
1304913056
llvm::map_range(definingOp.getBroadcastDimensions(),
1305013057
[&dims](int64_t dim) { return dims[dim]; }));

test/neuralgcm_test.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -111,9 +111,6 @@ def forward(initial_state, all_forcings):
111111
self.atol = 5e-2
112112
self.rtol = 1e-2
113113

114-
# TODO: we should fix this at some point
115-
self.skip_test_assert = True
116-
117114

118115
if __name__ == "__main__":
119116
from test_utils import fix_paths

test/test_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def fix_paths():
5151
# https://github.com/jax-ml/jax/blob/af36ae2cd783aea9eaa7979170df760a52542fcd/jax/_src/lib/__init__.py#L185
5252
os.environ["PYTHON_RUNFILES"] = runfiles
5353
# https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html
54-
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.95"
54+
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"
5555

5656
cuda_version = 12
5757
cuda_postfix = "_cu12"
@@ -419,8 +419,8 @@ def pipelines():
419419
setup_backends()
420420

421421
return [
422-
get_pipeline("JaxPipe"),
423422
get_pipeline("Jax"),
423+
get_pipeline("JaxPipe"),
424424
get_pipeline("HLOOpt"),
425425
get_pipeline("PartOpt"),
426426
get_pipeline("IPartOpt"),

0 commit comments

Comments
 (0)