Skip to content

Commit ce3d659

Browse files
committed
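Fixed a race condition that can occur when an expression contains an unused tile. This fix also provides a minor optimization: it avoids communication for unused tiles in certain circumstances (for example, the contraction no longer broadcasts a tile when the row or column group contains only one process; the tile is discarded instead).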
1 parent 8013413 commit ce3d659

File tree

5 files changed

+76
-18
lines changed

5 files changed

+76
-18
lines changed

src/TiledArray/dist_eval/array_eval.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,13 @@ namespace TiledArray {
195195
}
196196
}
197197

198+
/// Discard a tile that is not needed
199+
200+
/// This function handles the cleanup for tiles that are not needed in
201+
/// subsequent computation.
202+
/// \param i The index of the tile
203+
virtual void discard_tile(size_type i) const { }
204+
198205
private:
199206

200207
value_type make_tile(const typename array_type::value_type& tile, const bool consume) const {

src/TiledArray/dist_eval/binary_eval.h

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,13 @@ namespace TiledArray {
105105
return TensorImpl_::get_world().gop.template recv<value_type>(source, key);
106106
}
107107

108+
/// Discard a tile that is not needed
109+
110+
/// This function handles the cleanup for tiles that are not needed in
111+
/// subsequent computation.
112+
/// \param i The index of the tile
113+
virtual void discard_tile(size_type i) const { get_tile(i); }
114+
108115
private:
109116

110117
/// Task function for evaluating tiles
@@ -191,9 +198,9 @@ namespace TiledArray {
191198
} else {
192199
// Cleanup unused tiles
193200
if(! left_.is_zero(index))
194-
left_.get(index);
201+
left_.discard(index);
195202
if(! right_.is_zero(index))
196-
right_.get(index);
203+
right_.discard(index);
197204
}
198205
}
199206
}

src/TiledArray/dist_eval/contraction_eval.h

Lines changed: 39 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -958,14 +958,22 @@ namespace TiledArray {
958958
const madness::Group row_group = make_row_group(k);
959959
const ProcessID group_root = get_row_group_root(k, row_group);
960960

961-
// Broadcast column k of left_.
962-
for(; index < left_end_; index += left_stride_local_) {
963-
if(left_.shape().is_zero(index)) continue;
964-
965-
// Broadcast the tile
966-
const madness::DistributedID key(DistEvalImpl_::id(), index);
967-
auto tile = get_tile(left_, index);
968-
TensorImpl_::get_world().gop.bcast(key, tile, group_root, row_group);
961+
if(row_group.size() > 1) {
962+
// Broadcast column k of left_.
963+
for(; index < left_end_; index += left_stride_local_) {
964+
if(left_.shape().is_zero(index)) continue;
965+
966+
// Broadcast the tile
967+
const madness::DistributedID key(DistEvalImpl_::id(), index);
968+
auto tile = get_tile(left_, index);
969+
TensorImpl_::get_world().gop.bcast(key, tile, group_root, row_group);
970+
}
971+
} else {
972+
// Discard column k of left_.
973+
for(; index < left_end_; index += left_stride_local_) {
974+
if(left_.shape().is_zero(index)) continue;
975+
left_.discard(index);
976+
}
969977
}
970978

971979
break;
@@ -993,14 +1001,22 @@ namespace TiledArray {
9931001
const madness::Group col_group = make_col_group(k);
9941002
const ProcessID group_root = get_col_group_root(k, col_group);
9951003

996-
// Broadcast row k of right_.
997-
for(; index < row_end; index += right_stride_local_) {
998-
if(right_.shape().is_zero(index)) continue;
1004+
if(col_group.size() > 1) {
1005+
// Broadcast row k of right_.
1006+
for(; index < row_end; index += right_stride_local_) {
1007+
if(right_.shape().is_zero(index)) continue;
9991008

1000-
// Broadcast the tile
1001-
const madness::DistributedID key(DistEvalImpl_::id(), index + left_.size());
1002-
auto tile = get_tile(right_, index);
1003-
TensorImpl_::get_world().gop.bcast(key, tile, group_root, col_group);
1009+
// Broadcast the tile
1010+
const madness::DistributedID key(DistEvalImpl_::id(), index + left_.size());
1011+
auto tile = get_tile(right_, index);
1012+
TensorImpl_::get_world().gop.bcast(key, tile, group_root, col_group);
1013+
}
1014+
} else {
1015+
// Discard row k of right_.
1016+
for(; index < row_end; index += right_stride_local_) {
1017+
if(right_.shape().is_zero(index)) continue;
1018+
right_.discard(index);
1019+
}
10041020
}
10051021

10061022
break;
@@ -1761,6 +1777,14 @@ namespace TiledArray {
17611777
return TensorImpl_::get_world().gop.template recv<value_type>(source, key);
17621778
}
17631779

1780+
1781+
/// Discard a tile that is not needed
1782+
1783+
/// This function handles the cleanup for tiles that are not needed in
1784+
/// subsequent computation.
1785+
/// \param i The index of the tile
1786+
virtual void discard_tile(size_type i) const { get_tile(i); }
1787+
17641788
private:
17651789

17661790
/// Adjust iteration depth based on memory constraints

src/TiledArray/dist_eval/dist_eval.h

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,9 +127,15 @@ namespace TiledArray {
127127

128128
/// \param i The index of the tile
129129
/// \return Tile at index i
130-
/// \throw TiledArray::Exception When tile \c i is owned by a remote node.
131130
virtual Future<value_type> get_tile(size_type i) const = 0;
132131

132+
/// Discard a tile that is not needed
133+
134+
/// This function handles the cleanup for tiles that are not needed in
135+
/// subsequent computation.
136+
/// \param i The index of the tile
137+
virtual void discard_tile(size_type i) const = 0;
138+
133139
/// Set tensor value
134140

135141
/// This will store \c value at ordinal index \c i . Typically, this
@@ -322,6 +328,13 @@ namespace TiledArray {
322328
/// \return Tile \c i
323329
future get(size_type i) const { return pimpl_->get_tile(i); }
324330

331+
/// Discard a tile that is not needed
332+
333+
/// This function handles the cleanup for tiles that are not needed in
334+
/// subsequent computation.
335+
/// \param i The index of the tile
336+
virtual void discard(size_type i) const { pimpl_->discard_tile(i); }
337+
325338
/// World object accessor
326339

327340
/// \return A reference to the world object

src/TiledArray/dist_eval/unary_eval.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,13 @@ namespace TiledArray {
8585
return TensorImpl_::get_world().gop.template recv<value_type>(source, key);
8686
}
8787

88+
/// Discard a tile that is not needed
89+
90+
/// This function handles the cleanup for tiles that are not needed in
91+
/// subsequent computation.
92+
/// \param i The index of the tile
93+
virtual void discard_tile(size_type i) const { get_tile(i); }
94+
8895
private:
8996

9097
/// Input tile argument type

0 commit comments

Comments
 (0)