File tree Expand file tree Collapse file tree 3 files changed +10
-6
lines changed
Expand file tree Collapse file tree 3 files changed +10
-6
lines changed Original file line number Diff line number Diff line change 5353#include "llvm/ADT/SmallSet.h"
5454
5555#include "llvm/ADT/MapVector.h"
56+ #include <cassert>
5657#include <cstddef>
5758#include <iterator>
5859#include <numeric>
@@ -13045,6 +13046,12 @@ struct BroadcastInDimOpCanon final
1304513046 // Eliminate redundant nested BroadcastInDim.
1304613047 if (auto definingOp =
1304713048 operand.getDefiningOp<stablehlo::BroadcastInDimOp>()) {
13049+ DenseElementsAttr denseAttr;
13050+ if (matchPattern(definingOp.getOperand(), m_Constant(&denseAttr)) &&
13051+ !denseAttr.isSplat()) {
13052+ // TODO: investigate why this leads to incorrect results
13053+ return failure();
13054+ }
1304813055 auto newIndices = llvm::to_vector(
1304913056 llvm::map_range(definingOp.getBroadcastDimensions(),
1305013057 [&dims](int64_t dim) { return dims[dim]; }));
Original file line number Diff line number Diff line change @@ -111,9 +111,6 @@ def forward(initial_state, all_forcings):
111111 self .atol = 5e-2
112112 self .rtol = 1e-2
113113
114- # TODO: we should fix this at some point
115- self .skip_test_assert = True
116-
117114
118115if __name__ == "__main__" :
119116 from test_utils import fix_paths
Original file line number Diff line number Diff line change @@ -51,7 +51,7 @@ def fix_paths():
5151 # https://github.com/jax-ml/jax/blob/af36ae2cd783aea9eaa7979170df760a52542fcd/jax/_src/lib/__init__.py#L185
5252 os .environ ["PYTHON_RUNFILES" ] = runfiles
5353 # https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html
54- os .environ ["XLA_PYTHON_CLIENT_MEM_FRACTION" ] = "0.95 "
54+ # os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9 "
5555
5656 cuda_version = 12
5757 cuda_postfix = "_cu12"
@@ -419,11 +419,11 @@ def pipelines():
419419 setup_backends ()
420420
421421 return [
422- get_pipeline ("JaxPipe" ),
423422 get_pipeline ("Jax" ),
424- get_pipeline ("HLOOpt " ),
423+ get_pipeline ("JaxPipe " ),
425424 get_pipeline ("PartOpt" ),
426425 get_pipeline ("IPartOpt" ),
426+ get_pipeline ("HLOOpt" ),
427427 get_pipeline ("DefOpt" ),
428428 get_pipeline ("IDefOpt" ),
429429 ]
You can’t perform that action at this time.
0 commit comments