|
1 | | -diff --git a/shardy/dialect/sdy/transforms/export/explicit_reshards_util.cc b/shardy/dialect/sdy/transforms/export/explicit_reshards_util.cc |
2 | | -index f8e9369..f1bece8 100644 |
3 | | ---- a/shardy/dialect/sdy/transforms/export/explicit_reshards_util.cc |
4 | | -+++ b/shardy/dialect/sdy/transforms/export/explicit_reshards_util.cc |
5 | | -@@ -330,8 +330,7 @@ int64_t getCommunicationCost(const ShardingProjection& shardingProjection, |
6 | | - OpShardingRuleAttr shardingRule, |
7 | | - ArrayRef<int64_t> tensorSizes, |
8 | | - ArrayRef<int64_t> localTensorSizes, MeshAttr mesh, |
9 | | -- const FactorAxesPair& factorAxesPair, |
10 | | -- const int64_t expandedShardingSize) { |
11 | | -+ const FactorAxesPair& factorAxesPair) { |
12 | | - // The relative cost of collective operations. |
13 | | - constexpr int64_t allToAllCost = 1; |
14 | | - constexpr int64_t collectivePermuteCost = 2; |
15 | | -@@ -408,14 +407,10 @@ int64_t getCommunicationCost(const ShardingProjection& shardingProjection, |
16 | | - // If the result contains this factor, we need |
17 | | - // 1. all-to-all to move AX from this factor to other factors. |
18 | | - // 2. all-gather to shrink the sharding size after the all-to-all above. |
19 | | -- for (const auto& [localTensorSize, tensorFactorSharding] : llvm::zip_equal( |
20 | | -+ for (const auto& [tensorSize, tensorFactorSharding] : llvm::zip_equal( |
21 | | - localTensorSizes.drop_front(shardingProjection.getNumOperands()), |
22 | | - shardingProjection.getResults())) { |
23 | | -- // A candidate factor axes (factorAxesPair) is guaranteed to be an expansion |
24 | | -- // of its existing sharding and `localTensorSize` has already taken into its |
25 | | -- // existing sharding. In order to avoid double counting, it needs to shard |
26 | | -- // further on the expanded sharding size only. |
27 | | -- int64_t shardedTensorSize = localTensorSize / expandedShardingSize; |
28 | | -+ int64_t shardedTensorSize = tensorSize / axesXSize; |
29 | | - auto [axesA, axesB] = getShardingAxesInOtherAndThisFactor( |
30 | | - tensorFactorSharding, factorAxesPair.factorIndex); |
31 | | - |
32 | | -@@ -533,16 +528,9 @@ class FactorAxesCandidateBag { |
33 | | - |
34 | | - FactorAxesCandidate bestCandidate; |
35 | | - for (FactorAxesCandidate& candidate : candidates) { |
36 | | -- // NOTE: The axes on replication factors are distributed to batching |
37 | | -- // dimensions after the common axes are found for all non-replication |
38 | | -- // factors. The communication cost calculation does not take this into |
39 | | -- // account yet and hence is not ready for cases that sharding rule has |
40 | | -- // replication factors. |
41 | | -- if (shardingRule.getNeedReplicationFactors().empty()) { |
42 | | -- candidate.communicationCost = getCommunicationCost( |
43 | | -- shardingProjection, shardingRule, tensorSizes, localTensorSizes, |
44 | | -- mesh, candidate.factorAxes, candidate.shardingSize); |
45 | | -- } |
46 | | -+ candidate.communicationCost = |
47 | | -+ getCommunicationCost(shardingProjection, shardingRule, tensorSizes, |
48 | | -+ localTensorSizes, mesh, candidate.factorAxes); |
49 | | - if (isValid(candidate)) { |
50 | | - bestCandidate = std::max(bestCandidate, candidate); |
51 | | - } |
52 | | -diff --git a/shardy/dialect/sdy/transforms/export/test/insert_explicit_reshards/cholesky_triangular_solve.mlir b/shardy/dialect/sdy/transforms/export/test/insert_explicit_reshards/cholesky_triangular_solve.mlir |
53 | | -index 8b9401e..bb14074 100644 |
54 | | ---- a/shardy/dialect/sdy/transforms/export/test/insert_explicit_reshards/cholesky_triangular_solve.mlir |
55 | | -+++ b/shardy/dialect/sdy/transforms/export/test/insert_explicit_reshards/cholesky_triangular_solve.mlir |
56 | | -@@ -147,12 +147,12 @@ func.func @cholesky_cholesky_dims_shardings_can_merge(%arg0: tensor<16x8x8x8xf32 |
57 | | - return %0 : tensor<16x8x8x8xf32> |
58 | | - } |
59 | | - |
60 | | -+// TODO(zixuanjiang). We may want to keep 'x' due to its larger size. |
61 | | - // CHECK-LABEL: func @cholesky_sharded_cholesky_dim_input_only_batch_dim_both_but_input_sharding_larger |
62 | | - func.func @cholesky_sharded_cholesky_dim_input_only_batch_dim_both_but_input_sharding_larger(%arg0: tensor<8x4x8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_xyz, [{"x"}, {}, {}, {"z"}]>}) -> (tensor<8x4x8x8xf32> {sdy.sharding = #sdy.sharding<@mesh_xyz, [{"y"}, {}, {}, {}]>}){ |
63 | | -- // CHECK-NEXT: %[[RESHARD1:.*]] = sdy.reshard %arg0 <@mesh_xyz, [{"x"}, {}, {}, {}]> : tensor<8x4x8x8xf32> |
64 | | -- // CHECK-NEXT: %[[CHOLESKY:.*]] = stablehlo.cholesky %[[RESHARD1]], lower = true {sdy.sharding = #sdy.sharding_per_value<[<@mesh_xyz, [{"x"}, {}, {}, {}]>]>} : tensor<8x4x8x8xf32> |
65 | | -- // CHECK-NEXT: %[[RESHARD2:.*]] = sdy.reshard %[[CHOLESKY]] <@mesh_xyz, [{"y"}, {}, {}, {}]> : tensor<8x4x8x8xf32> |
66 | | -- // CHECK-NEXT: return %[[RESHARD2]] : tensor<8x4x8x8xf32> |
67 | | -+ // CHECK: %[[RESHARD1:.*]] = sdy.reshard %arg0 <@mesh_xyz, [{"y"}, {}, {}, {}]> : tensor<8x4x8x8xf32> |
68 | | -+ // CHECK-NEXT: %[[CHOLESKY:.*]] = stablehlo.cholesky %[[RESHARD1]], lower = true {sdy.sharding = #sdy.sharding_per_value<[<@mesh_xyz, [{"y"}, {}, {}, {}]>]>} : tensor<8x4x8x8xf32> |
69 | | -+ // CHECK-NEXT: return %[[CHOLESKY]] : tensor<8x4x8x8xf32> |
70 | | - %0 = stablehlo.cholesky %arg0, lower = true {sdy.sharding = #sdy.sharding_per_value<[<@mesh_xyz, [{"y"}, {}, {}, {}]>]>} : (tensor<8x4x8x8xf32>) -> tensor<8x4x8x8xf32> |
71 | | - return %0 : tensor<8x4x8x8xf32> |
72 | | - } |
73 | | -diff --git a/shardy/dialect/sdy/transforms/export/test/insert_explicit_reshards/dynamic_slice_dynamic_update_slice.mlir b/shardy/dialect/sdy/transforms/export/test/insert_explicit_reshards/dynamic_slice_dynamic_update_slice.mlir |
74 | | -index 904d776..c12086a 100644 |
75 | | ---- a/shardy/dialect/sdy/transforms/export/test/insert_explicit_reshards/dynamic_slice_dynamic_update_slice.mlir |
76 | | -+++ b/shardy/dialect/sdy/transforms/export/test/insert_explicit_reshards/dynamic_slice_dynamic_update_slice.mlir |
77 | | -@@ -42,10 +42,10 @@ func.func @dynamic_update_slice(%arg0: tensor<32x4x8xf32> {sdy.sharding = #sdy.s |
78 | | - |
79 | | - // CHECK-LABEL: func @dynamic_update_slice_different_input_and_output_sharding |
80 | | - func.func @dynamic_update_slice_different_input_and_output_sharding(%arg0: tensor<32x4x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"x"}, {"y"}]>}, %arg1: tensor<32x1x2xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {}, {"y"}]>}, %arg2: tensor<i32>, %arg3: tensor<i32>, %arg4: tensor<i32>) -> (tensor<32x4x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"y"}, {"x"}]>}){ |
81 | | -- // CHECK-NEXT: %[[RESHARD0:.*]] = sdy.reshard %arg0 <@mesh, [{}, {"y"}, {"x"}]> : tensor<32x4x8xf32> |
82 | | -- // CHECK-NEXT: %[[RESHARD1:.*]] = sdy.reshard %arg1 <@mesh, [{}, {}, {}]> : tensor<32x1x2xf32> |
83 | | -- // CHECK-NEXT: %[[DYNAMIC_UPDATE_SLICE:.*]] = stablehlo.dynamic_update_slice %[[RESHARD0]], %[[RESHARD1]], %arg2, %arg3, %arg4 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {"y"}, {"x"}]>]>} : (tensor<32x4x8xf32>, tensor<32x1x2xf32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<32x4x8xf32> |
84 | | -- // CHECK-NEXT: return %[[DYNAMIC_UPDATE_SLICE]] : tensor<32x4x8xf32> |
85 | | -+ // CHECK-NEXT: %[[REPLICATED_UPDATE:.*]] = sdy.reshard %arg1 <@mesh, [{}, {}, {}]> : tensor<32x1x2xf32> |
86 | | -+ // CHECK-NEXT: %[[DYNAMIC_UPDATE_SLICE:.*]] = stablehlo.dynamic_update_slice %arg0, %[[REPLICATED_UPDATE]], %arg2, %arg3, %arg4 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {"x"}, {"y"}]>]>} : (tensor<32x4x8xf32>, tensor<32x1x2xf32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<32x4x8xf32> |
87 | | -+ // CHECK-NEXT: %[[RESHARD:.*]] = sdy.reshard %[[DYNAMIC_UPDATE_SLICE]] <@mesh, [{}, {"y"}, {"x"}]> : tensor<32x4x8xf32> |
88 | | -+ // CHECK-NEXT: return %[[RESHARD]] : tensor<32x4x8xf32> |
89 | | - %0 = stablehlo.dynamic_update_slice %arg0, %arg1, %arg2, %arg3, %arg4 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {"y"}, {"x"}]>]>} : (tensor<32x4x8xf32>, tensor<32x1x2xf32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<32x4x8xf32> |
90 | | - return %0 : tensor<32x4x8xf32> |
91 | | - } |
92 | 1 | diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch |
93 | | -index 8327093..e25178e 100644 |
| 2 | +index e25178e..37a7256 100644 |
94 | 3 | --- a/third_party/llvm/generated.patch |
95 | 4 | +++ b/third_party/llvm/generated.patch |
96 | | -@@ -1,20 +1,17 @@ |
| 5 | +@@ -1,17 +1,29 @@ |
97 | 6 | Auto generated patch. Do not edit or delete it, even if empty. |
98 | | --diff -ruN --strip-trailing-cr a/utils/bazel/llvm-project-overlay/clang/BUILD.bazel b/utils/bazel/llvm-project-overlay/clang/BUILD.bazel |
99 | | ----- a/utils/bazel/llvm-project-overlay/clang/BUILD.bazel |
100 | | --+++ b/utils/bazel/llvm-project-overlay/clang/BUILD.bazel |
101 | | --@@ -1563,7 +1563,6 @@ |
102 | | -- ":basic", |
103 | | -- ":config", |
104 | | -- ":driver_options_inc_gen", |
105 | | --- ":frontend", |
106 | | -- ":lex", |
107 | | -- ":options", |
108 | | -- ":parse", |
109 | | --@@ -1719,6 +1718,7 @@ |
110 | | -- ":ast", |
111 | | -- ":basic", |
112 | | -- ":config", |
113 | | --+ ":driver", |
114 | | -- ":driver_options_inc_gen", |
115 | | -- ":edit", |
116 | | -- ":lex", |
117 | | -+diff -ruN --strip-trailing-cr a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp |
118 | | -+--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp |
119 | | -++++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp |
120 | | -+@@ -4095,6 +4095,12 @@ |
121 | | -+ llvm::SmallVector<size_t> occludedChildren; |
122 | | -+ llvm::sort( |
123 | | -+ indices.begin(), indices.end(), [&](const size_t a, const size_t b) { |
124 | | -++ // Bail early if we are asked to look at the same index. If we do not |
125 | | -++ // bail early, we can end up mistakenly adding indices to |
126 | | -++ // occludedChildren. This can occur with some types of libc++ hardening. |
127 | | -++ if (a == b) |
128 | | -++ return false; |
| 7 | +-diff -ruN --strip-trailing-cr a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp |
| 8 | +---- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp |
| 9 | +-+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp |
| 10 | +-@@ -4095,6 +4095,12 @@ |
| 11 | +- llvm::SmallVector<size_t> occludedChildren; |
| 12 | +- llvm::sort( |
| 13 | +- indices.begin(), indices.end(), [&](const size_t a, const size_t b) { |
| 14 | +-+ // Bail early if we are asked to look at the same index. If we do not |
| 15 | +-+ // bail early, we can end up mistakenly adding indices to |
| 16 | +-+ // occludedChildren. This can occur with some types of libc++ hardening. |
| 17 | +-+ if (a == b) |
| 18 | +-+ return false; |
| 19 | +-+ |
| 20 | +- auto memberIndicesA = cast<ArrayAttr>(indexAttr[a]); |
| 21 | +- auto memberIndicesB = cast<ArrayAttr>(indexAttr[b]); |
| 22 | ++diff -ruN --strip-trailing-cr a/clang/lib/Analysis/ThreadSafety.cpp b/clang/lib/Analysis/ThreadSafety.cpp |
| 23 | ++--- a/clang/lib/Analysis/ThreadSafety.cpp |
| 24 | +++++ b/clang/lib/Analysis/ThreadSafety.cpp |
| 25 | ++@@ -2820,7 +2820,7 @@ |
| 26 | ++ case CFGElement::AutomaticObjectDtor: { |
| 27 | ++ CFGAutomaticObjDtor AD = BI.castAs<CFGAutomaticObjDtor>(); |
| 28 | ++ const auto *DD = AD.getDestructorDecl(AC.getASTContext()); |
| 29 | ++- if (!DD->hasAttrs()) |
| 30 | +++ if (!DD || !DD->hasAttrs()) |
| 31 | ++ break; |
| 32 | + |
| 33 | ++ LocksetBuilder.handleCall( |
| 34 | ++diff -ruN --strip-trailing-cr a/clang/test/SemaCXX/no-warn-thread-safety-analysis.cpp b/clang/test/SemaCXX/no-warn-thread-safety-analysis.cpp |
| 35 | ++--- a/clang/test/SemaCXX/no-warn-thread-safety-analysis.cpp |
| 36 | +++++ b/clang/test/SemaCXX/no-warn-thread-safety-analysis.cpp |
| 37 | ++@@ -0,0 +1,12 @@ |
| 38 | +++// RUN: %clang_cc1 -fsyntax-only -verify -std=c++11 -Wthread-safety -Wthread-safety-pointer -Wthread-safety-beta -Wno-thread-safety-negative -fcxx-exceptions -DUSE_CAPABILITY=0 %s |
| 39 | +++// RUN: %clang_cc1 -fsyntax-only -verify -std=c++11 -Wthread-safety -Wthread-safety-pointer -Wthread-safety-beta -Wno-thread-safety-negative -fcxx-exceptions -DUSE_CAPABILITY=1 %s |
| 40 | +++// RUN: %clang_cc1 -fsyntax-only -verify -std=c++17 -Wthread-safety -Wthread-safety-pointer -Wthread-safety-beta -Wno-thread-safety-negative -fcxx-exceptions -DUSE_CAPABILITY=0 %s |
| 41 | +++// RUN: %clang_cc1 -fsyntax-only -verify -std=c++17 -Wthread-safety -Wthread-safety-pointer -Wthread-safety-beta -Wno-thread-safety-negative -fcxx-exceptions -DUSE_CAPABILITY=1 %s |
| 42 | +++// expected-no-diagnostics |
129 | 43 | ++ |
130 | | -+ auto memberIndicesA = cast<ArrayAttr>(indexAttr[a]); |
131 | | -+ auto memberIndicesB = cast<ArrayAttr>(indexAttr[b]); |
132 | | -+ |
| 44 | +++struct foo { |
| 45 | +++ ~foo(); |
| 46 | +++}; |
| 47 | +++struct bar : foo {}; |
| 48 | +++struct baz : bar {}; |
| 49 | +++baz foobar(baz a) { return a; } |
133 | 50 | diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl |
134 | | -index 8e6cbc9..215fe72 100644 |
| 51 | +index 215fe72..3b35979 100644 |
135 | 52 | --- a/third_party/llvm/workspace.bzl |
136 | 53 | +++ b/third_party/llvm/workspace.bzl |
137 | 54 | @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") |
138 | 55 |
|
139 | 56 | def repo(name): |
140 | 57 | """Imports LLVM.""" |
141 | | -- LLVM_COMMIT = "dea330b38d9c18b68219abdb52baaa72c9f1103d" |
142 | | -- LLVM_SHA256 = "0f00dd4e0d61e49051b09169450af0c5ca364bf7e3f015794089455ae8c8555c" |
143 | | -+ LLVM_COMMIT = "26362c68579dd4375198aae4651b4d5f8a36c715" |
144 | | -+ LLVM_SHA256 = "1b81809d98940d0a6d4f19ef9e0bf72cd5847b9bbed47bc3517fcf8a40d38fd9" |
| 58 | +- LLVM_COMMIT = "26362c68579dd4375198aae4651b4d5f8a36c715" |
| 59 | +- LLVM_SHA256 = "1b81809d98940d0a6d4f19ef9e0bf72cd5847b9bbed47bc3517fcf8a40d38fd9" |
| 60 | ++ LLVM_COMMIT = "4f39a4ff0ada92870ca1c2dccad382ea04947da8" |
| 61 | ++ LLVM_SHA256 = "264c7cc3e166b840911494a2e94cff2ae8730b56239bd91dc8f65c8dc9468262" |
145 | 62 |
|
146 | 63 | tf_http_archive( |
147 | 64 | name = name, |
0 commit comments