Skip to content

Commit fe6d848

Browse files
authored
Merge pull request #492 from ValeevGroup/gaudel/fix/zero_elem_t_in_tot
Support non-zero ToTs with some zero inner Ts
2 parents b81da44 + bc69ec5 commit fe6d848

File tree

6 files changed, with 71 additions and 41 deletions.

src/TiledArray/einsum/tiledarray.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -485,8 +485,13 @@ auto einsum(expressions::TsrExpr<ArrayA_> A, expressions::TsrExpr<ArrayB_> B,
485485
// Step IV: C2(ijpq) -> C(ipjq)
486486

487487
auto sum_tot_2_tos = [](auto const &tot) {
488-
typename std::remove_reference_t<decltype(tot)>::value_type result(
489-
tot.range(), [tot](auto &&ix) { return tot(ix).sum(); });
488+
using tot_t = std::remove_reference_t<decltype(tot)>;
489+
typename tot_t::value_type result(
490+
tot.range(), [tot](auto &&ix) {
491+
if (!tot(ix).empty())
492+
return tot(ix).sum();
493+
else return typename tot_t::numeric_type{};
494+
});
490495
return result;
491496
};
492497

src/TiledArray/expressions/cont_engine.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -513,7 +513,8 @@ class ContEngine : public BinaryEngine<Derived> {
513513
const left_tile_element_type& left,
514514
const right_tile_element_type& right) {
515515
contrreduce_op(result, left, right);
516-
result = contrreduce_op(result); // permutations of result are applied as "postprocessing"
516+
if (!TA::empty(result))
517+
result = contrreduce_op(result); // permutations of result are applied as "postprocessing"
517518
};
518519
} // ToT x ToT
519520
} else if (inner_prod == TensorProduct::Hadamard) {

src/TiledArray/tensor/kernels.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -996,6 +996,8 @@ auto tensor_reduce(ReduceOp&& reduce_op, JoinOp&& join_op,
996996

997997
auto result = identity;
998998
for (std::remove_cv_t<decltype(volume)> ord = 0ul; ord < volume; ++ord) {
999+
if (tensor1.data()[ord].range().volume() == 0
1000+
|| ((tensors.data()[ord].range().volume() == 0) || ...)) continue;
9991001
auto temp = tensor_reduce(reduce_op, join_op, identity, tensor1.data()[ord],
10001002
tensors.data()[ord]...);
10011003
join_op(result, temp);

src/TiledArray/tensor/tensor.h

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -431,7 +431,8 @@ class Tensor {
431431
auto volume = total_size();
432432
for (decltype(volume) i = 0; i < volume; ++i) {
433433
auto& el = *(data() + i);
434-
el = p(el, inner_perm);
434+
if (!el.empty())
435+
el = p(el, inner_perm);
435436
}
436437
}
437438
}
@@ -588,9 +589,13 @@ class Tensor {
588589
Tensor clone() const {
589590
Tensor result;
590591
if (data_) {
591-
result = detail::tensor_op<Tensor>(
592-
[](const numeric_type value) -> numeric_type { return value; },
593-
*this);
592+
if constexpr (detail::is_tensor_of_tensor_v<Tensor>) {
593+
result = Tensor(*this, [](value_type const& el) { return el.clone(); });
594+
} else {
595+
result = detail::tensor_op<Tensor>(
596+
[](const numeric_type value) -> numeric_type { return value; },
597+
*this);
598+
}
594599
} else if (range_) { // corner case: data_ = null implies range_.volume()
595600
// == 0;
596601
TA_ASSERT(range_.volume() == 0);
@@ -1538,6 +1543,7 @@ class Tensor {
15381543
detail::is_bipartite_permutation_v<Perm>;
15391544
// tile ops pass bipartite permutations here even if this is a plain tensor
15401545
if constexpr (!is_tot) {
1546+
if (empty()) return *this;
15411547
if constexpr (is_bperm) {
15421548
TA_ASSERT(inner_size(perm) == 0); // ensure this is a plain permutation
15431549
return Tensor(*this, op, outer(std::forward<Perm>(perm)));
@@ -1574,6 +1580,7 @@ class Tensor {
15741580
template <typename Scalar, typename std::enable_if<
15751581
detail::is_numeric_v<Scalar>>::type* = nullptr>
15761582
Tensor scale(const Scalar factor) const {
1583+
if (range().volume() == 0) return *this;
15771584
return unary([factor](const value_type& a) -> decltype(auto) {
15781585
using namespace TiledArray::detail;
15791586
return a * factor;
@@ -1626,6 +1633,10 @@ class Tensor {
16261633
return binary(
16271634
right,
16281635
[](const value_type& l, const value_t<Right>& r) -> decltype(auto) {
1636+
if constexpr (detail::is_tensor_v<value_type>) {
1637+
if (l.empty() && r.empty())
1638+
return value_type{};
1639+
}
16291640
return l + r;
16301641
});
16311642
}
@@ -1740,6 +1751,7 @@ class Tensor {
17401751
template <typename Right,
17411752
typename std::enable_if<is_tensor<Right>::value>::type* = nullptr>
17421753
Tensor& add_to(const Right& right) {
1754+
if (right.empty()) return *this;
17431755
if (empty()) {
17441756
*this = Tensor{right.range(), value_type{}};
17451757
}
@@ -1923,11 +1935,17 @@ class Tensor {
19231935
typename std::enable_if<detail::is_nested_tensor_v<Right>>::type* =
19241936
nullptr>
19251937
decltype(auto) mult(const Right& right) const {
1926-
return binary(
1927-
right,
1928-
[](const value_type& l, const value_t<Right>& r) -> decltype(auto) {
1929-
return l * r;
1930-
});
1938+
1939+
auto mult_op =[](const value_type& l, const value_t<Right>& r) -> decltype(auto) {
1940+
return l * r;
1941+
};
1942+
1943+
if (empty() || right.empty()) {
1944+
using res_t = decltype(std::declval<Tensor>().binary(std::declval<Right>(), mult_op));
1945+
return res_t{};
1946+
}
1947+
1948+
return binary(right, mult_op);
19311949
}
19321950

19331951
/// Multiply this by \c right to create a new, permuted tensor

src/TiledArray/tile_op/contract_reduce.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -326,17 +326,17 @@ class ContractReduce : public ContractReduceBase<Result, Left, Right, Scalar> {
326326
/// \param[in] right The right-hand tile to be contracted
327327
void operator()(result_type& result, const first_argument_type& left,
328328
const second_argument_type& right) const {
329+
using TiledArray::empty;
330+
using TiledArray::gemm;
331+
if (empty(left) || empty(right)) return;
332+
329333
if constexpr (!ContractReduceBase_::plain_tensors) {
330334
TA_ASSERT(this->elem_muladd_op());
331335
// not yet implemented
332-
using TiledArray::empty;
333-
using TiledArray::gemm;
334336
gemm(result, left, right, ContractReduceBase_::gemm_helper(),
335337
this->elem_muladd_op());
336338
} else { // plain tensors
337339
TA_ASSERT(!this->elem_muladd_op());
338-
using TiledArray::empty;
339-
using TiledArray::gemm;
340340
if (empty(result))
341341
result = gemm(left, right, ContractReduceBase_::factor(),
342342
ContractReduceBase_::gemm_helper());

tests/retile.cpp

Lines changed: 29 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -6,26 +6,24 @@
66
BOOST_AUTO_TEST_SUITE(retile_suite)
77

88
BOOST_AUTO_TEST_CASE(retile_tensor) {
9-
TA::detail::matrix_il<double> some_values = {
10-
{0.1, 0.2, 0.3, 0.4, 0.5},
11-
{0.6, 0.7, 0.8, 0.9, 1.0},
12-
{1.1, 1.2, 1.3, 1.4, 1.5},
13-
{1.6, 1.7, 1.8, 1.9, 2.0},
14-
{2.1, 2.2, 2.3, 2.4, 2.5}
15-
};
16-
17-
auto range0 = TA::TiledRange1(0, 3, 5);
18-
auto range1 = TA::TiledRange1(0, 4, 5);
19-
auto trange = TA::TiledRange({range0, range1});
20-
21-
TA::TArrayD default_dense(*GlobalFixture::world, some_values);
22-
TA::TSpArrayD default_sparse(*GlobalFixture::world, some_values);
23-
24-
auto result_dense = retile(default_dense, trange);
25-
auto result_sparse = retile(default_sparse, trange);
26-
27-
BOOST_CHECK_EQUAL(result_dense.trange(), trange);
28-
BOOST_CHECK_EQUAL(result_sparse.trange(), trange);
9+
TA::detail::matrix_il<double> some_values = {{0.1, 0.2, 0.3, 0.4, 0.5},
10+
{0.6, 0.7, 0.8, 0.9, 1.0},
11+
{1.1, 1.2, 1.3, 1.4, 1.5},
12+
{1.6, 1.7, 1.8, 1.9, 2.0},
13+
{2.1, 2.2, 2.3, 2.4, 2.5}};
14+
15+
auto range0 = TA::TiledRange1(0, 3, 5);
16+
auto range1 = TA::TiledRange1(0, 4, 5);
17+
auto trange = TA::TiledRange({range0, range1});
18+
19+
TA::TArrayD default_dense(*GlobalFixture::world, some_values);
20+
TA::TSpArrayD default_sparse(*GlobalFixture::world, some_values);
21+
22+
auto result_dense = retile(default_dense, trange);
23+
auto result_sparse = retile(default_sparse, trange);
24+
25+
BOOST_CHECK_EQUAL(result_dense.trange(), trange);
26+
BOOST_CHECK_EQUAL(result_sparse.trange(), trange);
2927
}
3028

3129
BOOST_AUTO_TEST_CASE(retile_more) {
@@ -69,17 +67,20 @@ BOOST_AUTO_TEST_CASE(retile_more) {
6967
return tile.norm();
7068
};
7169

70+
auto arr_source0 =
71+
TA::make_array<ArrayT>(world, tr_source, set_random_tensor_tile);
72+
auto arr_target0 = TA::retile(arr_source0, tr_target);
73+
7274
auto get_elem = [](auto const& arr, auto const& eix) {
7375
auto tix = arr.trange().element_to_tile(eix);
7476
auto&& tile = arr.find(tix).get(false);
7577
return tile(eix);
7678
};
7779

78-
auto arr_source0 =
79-
TA::make_array<ArrayT>(world, tr_source, set_random_tensor_tile);
80-
auto arr_target0 = TA::retile(arr_source0, tr_target);
81-
8280
for (auto&& eix : elem_rng) {
81+
auto tix = arr_source0.trange().element_to_tile(eix);
82+
BOOST_REQUIRE(arr_source0.is_zero(tix) == arr_target0.is_zero(tix));
83+
if (arr_source0.is_zero(tix)) continue;
8384
BOOST_REQUIRE(get_elem(arr_source0, eix) == get_elem(arr_target0, eix));
8485
}
8586

@@ -94,8 +95,11 @@ BOOST_AUTO_TEST_CASE(retile_more) {
9495
world.gop.fence();
9596

9697
for (auto&& eix : elem_rng) {
98+
auto tix = arr_source.trange().element_to_tile(eix);
99+
BOOST_REQUIRE(arr_source.is_zero(tix) == arr_target.is_zero(tix));
100+
if (arr_source.is_zero(tix)) continue;
97101
BOOST_REQUIRE(get_elem(arr_source, eix) == get_elem(arr_target, eix));
98102
}
99103
}
100104

101-
BOOST_AUTO_TEST_SUITE_END()
105+
BOOST_AUTO_TEST_SUITE_END()

0 commit comments

Comments
 (0)