@@ -958,14 +958,22 @@ namespace TiledArray {
958958 const madness::Group row_group = make_row_group (k);
959959 const ProcessID group_root = get_row_group_root (k, row_group);
960960
961- // Broadcast column k of left_.
962- for (; index < left_end_; index += left_stride_local_) {
963- if (left_.shape ().is_zero (index)) continue ;
964-
965- // Broadcast the tile
966- const madness::DistributedID key (DistEvalImpl_::id (), index);
967- auto tile = get_tile (left_, index);
968- TensorImpl_::get_world ().gop .bcast (key, tile, group_root, row_group);
961+ if (row_group.size () > 1 ) {
962+ // Broadcast column k of left_.
963+ for (; index < left_end_; index += left_stride_local_) {
964+ if (left_.shape ().is_zero (index)) continue ;
965+
966+ // Broadcast the tile
967+ const madness::DistributedID key (DistEvalImpl_::id (), index);
968+ auto tile = get_tile (left_, index);
969+ TensorImpl_::get_world ().gop .bcast (key, tile, group_root, row_group);
970+ }
971+ } else {
972+ // Discard column k of left_.
973+ for (; index < left_end_; index += left_stride_local_) {
974+ if (left_.shape ().is_zero (index)) continue ;
975+ left_.discard (index);
976+ }
969977 }
970978
971979 break ;
@@ -993,14 +1001,22 @@ namespace TiledArray {
9931001 const madness::Group col_group = make_col_group (k);
9941002 const ProcessID group_root = get_col_group_root (k, col_group);
9951003
996- // Broadcast row k of right_.
997- for (; index < row_end; index += right_stride_local_) {
998- if (right_.shape ().is_zero (index)) continue ;
1004+ if (col_group.size () > 1 ) {
1005+ // Broadcast row k of right_.
1006+ for (; index < row_end; index += right_stride_local_) {
1007+ if (right_.shape ().is_zero (index)) continue ;
9991008
1000- // Broadcast the tile
1001- const madness::DistributedID key (DistEvalImpl_::id (), index + left_.size ());
1002- auto tile = get_tile (right_, index);
1003- TensorImpl_::get_world ().gop .bcast (key, tile, group_root, col_group);
1009+ // Broadcast the tile
1010+ const madness::DistributedID key (DistEvalImpl_::id (), index + left_.size ());
1011+ auto tile = get_tile (right_, index);
1012+ TensorImpl_::get_world ().gop .bcast (key, tile, group_root, col_group);
1013+ }
1014+ } else {
1015+ // Broadcast row k of right_.
1016+ for (; index < row_end; index += right_stride_local_) {
1017+ if (right_.shape ().is_zero (index)) continue ;
1018+ right_.discard (index);
1019+ }
10041020 }
10051021
10061022 break ;
@@ -1761,6 +1777,14 @@ namespace TiledArray {
17611777 return TensorImpl_::get_world ().gop .template recv <value_type>(source, key);
17621778 }
17631779
1780+
1781+ // / Discard a tile that is not needed
1782+
1783+ // / This function handles the cleanup for tiles that are not needed in
1784+ // / subsequent computation.
1785+ // / \param i The index of the tile
1786+ virtual void discard_tile (size_type i) const { get_tile (i); }
1787+
17641788 private:
17651789
17661790 // / Adjust iteration depth based on memory constraints
0 commit comments