Skip to content

Commit 5d9bf28

Browse files
committed
support for adding/subtracting constant
1 parent 9552625 commit 5d9bf28

File tree

2 files changed

+44
-0
lines changed

2 files changed

+44
-0
lines changed

btas/tensor.h

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -585,6 +585,24 @@ namespace btas {
585585
return y; /* automatically called move semantics */
586586
}
587587

588+
/// adds a number to every element
589+
Tensor operator+(const value_type& x) const {
590+
Tensor y = this->clone();
591+
y += x;
592+
return y; /* automatically called move semantics */
593+
}
594+
595+
/// adds a number to every element
596+
Tensor& operator+=(const value_type& x) {
597+
using std::begin;
598+
using std::cbegin;
599+
using std::cend;
600+
std::transform(cbegin(storage_), cend(storage_), begin(storage_), [x](const auto& v) {
601+
return v+x;
602+
});
603+
return *this;
604+
}
605+
588606
/// subtraction assignment
589607
Tensor& operator-=(const Tensor& x) {
590608
using std::begin;
@@ -603,6 +621,24 @@ namespace btas {
603621
return y; /* automatically called move semantics */
604622
}
605623

624+
/// subtracts a number from every element
625+
Tensor operator-(const value_type& x) const {
626+
Tensor y = this->clone();
627+
y -= x;
628+
return y; /* automatically called move semantics */
629+
}
630+
631+
/// subtracts a number from every element
632+
Tensor& operator-=(const value_type& x) {
633+
using std::begin;
634+
using std::cbegin;
635+
using std::cend;
636+
std::transform(cbegin(storage_), cend(storage_), begin(storage_), [x](const auto& v) {
637+
return v-x;
638+
});
639+
return *this;
640+
}
641+
606642
/// \return bare const pointer to the first element of data_
607643
/// this enables to call BLAS functions
608644
const_pointer data() const {

unittest/tensor_test.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,14 @@ TEST_CASE("Tensor Operations") {
219219
CHECK_NOTHROW( static_cast<double>(T0) == 11);
220220
}
221221

222+
SECTION("Add/subtract constant") {
223+
T3.fill(1.);
224+
auto T3_plus_1 = T3 + 1.;
225+
for (auto x : T3_plus_1) CHECK(x == 2.);
226+
T3_plus_1 -= 1.;
227+
for (auto x : T3_plus_1) CHECK(x == 1.);
228+
}
229+
222230
SECTION("Generate") {
223231
std::vector<double> data(T3.size());
224232
for (auto& x : data) x = rng();

0 commit comments

Comments
 (0)