Skip to content

Commit ed1cd88

Browse files
authored
Merge pull request #129 from marty1885/apichange
more indexing method
2 parents ffb5cf8 + 958af35 commit ed1cd88

File tree

3 files changed

+32
-24
lines changed

3 files changed

+32
-24
lines changed

Etaler/Core/Tensor.hpp

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -211,15 +211,15 @@ struct ETALER_EXPORT Tensor
211211
Tensor log() const { return backend()->log(pimpl()); }
212212
Tensor logical_not() const { return backend()->logical_not(pimpl()); }
213213

214-
Tensor add(const Tensor& other) const { auto [a, b] = brodcast(other); return backend()->add(a(), b()); }
215-
Tensor subtract(const Tensor& other) const { auto [a, b] = brodcast(other); return backend()->subtract(a(), b()); }
216-
Tensor mul(const Tensor& other) const { auto [a, b] = brodcast(other); return backend()->mul(a(), b()); }
217-
Tensor div(const Tensor& other) const { auto [a, b] = brodcast(other); return backend()->div(a(), b()); }
218-
Tensor equal(const Tensor& other) const { auto [a, b] = brodcast(other); return backend()->equal(a(), b()); }
219-
Tensor greater(const Tensor& other) const { auto [a, b] = brodcast(other); return backend()->greater(a(), b()); }
220-
Tensor lesser(const Tensor& other) const { auto [a, b] = brodcast(other); return backend()->lesser(a(), b()); }
221-
Tensor logical_and(const Tensor& other) const { auto [a, b] = brodcast(other); return backend()->logical_and(a(), b()); }
222-
Tensor logical_or(const Tensor& other) const { auto [a, b] = brodcast(other); return backend()->logical_or(a(), b()); }
214+
Tensor add(const Tensor& other) const { auto [a, b] = brodcast(other); return backend()->add(a.pimpl(), b.pimpl()); }
215+
Tensor subtract(const Tensor& other) const { auto [a, b] = brodcast(other); return backend()->subtract(a.pimpl(), b.pimpl()); }
216+
Tensor mul(const Tensor& other) const { auto [a, b] = brodcast(other); return backend()->mul(a.pimpl(), b.pimpl()); }
217+
Tensor div(const Tensor& other) const { auto [a, b] = brodcast(other); return backend()->div(a.pimpl(), b.pimpl()); }
218+
Tensor equal(const Tensor& other) const { auto [a, b] = brodcast(other); return backend()->equal(a.pimpl(), b.pimpl()); }
219+
Tensor greater(const Tensor& other) const { auto [a, b] = brodcast(other); return backend()->greater(a.pimpl(), b.pimpl()); }
220+
Tensor lesser(const Tensor& other) const { auto [a, b] = brodcast(other); return backend()->lesser(a.pimpl(), b.pimpl()); }
221+
Tensor logical_and(const Tensor& other) const { auto [a, b] = brodcast(other); return backend()->logical_and(a.pimpl(), b.pimpl()); }
222+
Tensor logical_or(const Tensor& other) const { auto [a, b] = brodcast(other); return backend()->logical_or(a.pimpl(), b.pimpl()); }
223223

224224
inline bool any() const { return cast(DType::Bool).sum(std::nullopt, DType::Bool).item<uint8_t>(); }
225225
inline bool all() const { return cast(DType::Bool).sum(std::nullopt).item<int32_t>() == int32_t(size()); }
@@ -252,14 +252,14 @@ struct ETALER_EXPORT Tensor
252252

253253
//Subscription operator
254254
Tensor operator [] (const IndexList& r) { return view(r); }
255+
template <typename ... Args>
256+
Tensor operator () (Args ... args) { return view({args ...}); }
255257

256258
Tensor sum(std::optional<intmax_t> dim=std::nullopt, DType dtype=DType::Unknown) const;
257259
Tensor abs() const { return backend()->abs(pimpl()); }
258260
bool isSame (const Tensor& other) const;
259261

260262
//Utils
261-
TensorImpl* operator () () {return pimpl();}
262-
const TensorImpl* operator () () const {return pimpl();}
263263

264264
using iterator = TensorIterator<Tensor>;
265265
using const_iterator = TensorIterator<const Tensor>;
@@ -332,18 +332,18 @@ inline Tensor cellActivity(const Tensor& x, const Tensor& connections, const Ten
332332
return x;
333333
return x.cast(DType::Bool);
334334
}();
335-
return x.backend()->cellActivity(input(), connections(), permeances(), connected_permeance, active_threshold, has_unconnected_synapse);
335+
return x.backend()->cellActivity(input.pimpl(), connections.pimpl(), permeances.pimpl(), connected_permeance, active_threshold, has_unconnected_synapse);
336336
}
337337

338338
inline void learnCorrilation(const Tensor& x, const Tensor& learn, const Tensor& connection
339339
, Tensor& permeances, float perm_inc, float perm_dec, bool has_unconnected_synapse=true)
340340
{
341-
x.backend()->learnCorrilation(x(), learn(), connection(), permeances(), perm_inc, perm_dec, has_unconnected_synapse);
341+
x.backend()->learnCorrilation(x.pimpl(), learn.pimpl(), connection.pimpl(), permeances.pimpl(), perm_inc, perm_dec, has_unconnected_synapse);
342342
}
343343

344344
inline Tensor globalInhibition(const Tensor& x, float fraction)
345345
{
346-
return x.backend()->globalInhibition(x(), fraction);
346+
return x.backend()->globalInhibition(x.pimpl(), fraction);
347347
}
348348

349349
Tensor inline cast(const Tensor& x, DType dtype)
@@ -358,27 +358,27 @@ inline Tensor copy(const Tensor& x)
358358

359359
inline void sortSynapse(Tensor& connection, Tensor& permeances)
360360
{
361-
connection.backend()->sortSynapse(connection(), permeances());
361+
connection.backend()->sortSynapse(connection.pimpl(), permeances.pimpl());
362362
}
363363

364364
inline Tensor burst(const Tensor& x, const Tensor& s)
365365
{
366-
return x.backend()->burst(x(), s());
366+
return x.backend()->burst(x.pimpl(), s.pimpl());
367367
}
368368

369369
inline Tensor reverseBurst(const Tensor& x)
370370
{
371-
return x.backend()->reverseBurst(x());
371+
return x.backend()->reverseBurst(x.pimpl());
372372
}
373373

374374
inline void growSynapses(const Tensor& x, const Tensor& y, Tensor& connections, Tensor& permeances, float init_perm)
375375
{
376-
x.backend()->growSynapses(x(), y(), connections(), permeances(), init_perm);
376+
x.backend()->growSynapses(x.pimpl(), y.pimpl(), connections.pimpl(), permeances.pimpl(), init_perm);
377377
}
378378

379379
inline void decaySynapses(Tensor& connections, Tensor& permeances, float threshold)
380380
{
381-
connections.backend()->decaySynapses(connections(), permeances(), threshold);
381+
connections.backend()->decaySynapses(connections.pimpl(), permeances.pimpl(), threshold);
382382
}
383383

384384
inline void assign(Tensor& x, const Tensor& y)
@@ -420,6 +420,10 @@ inline Tensor logical_or(const Tensor& x1, const Tensor& x2) { return x1.logical
420420
inline bool all(const Tensor& t) { return t.all(); }
421421
inline bool any(const Tensor& t) { return t.any(); }
422422

423+
template <typename ... Args>
424+
inline Tensor view(const Tensor& t, Args... args) { return t.view({args...}); }
425+
inline Tensor dynamic_view(const Tensor& t, const IndexList& indices) { return t.view(indices); }
426+
423427
inline Tensor zeros_like(const Tensor& x) { return zeros(x.shape(), x.dtype(), x.backend()); }
424428
inline Tensor ones_like(const Tensor& x) { return ones(x.shape(), x.dtype(), x.backend()); }
425429
}

docs/source/PythonBindings.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Python bindings
22

33
## PyEtaler
4-
[PyEtaler](https://guthub.com/etaler/pyetaler) is the offical binding for Etaler. We try to keep the Python API as close to the C++ one as possible. So you can use the C++ document as the Python document. With that said, some functions are changed in the binding to make it more Pythonic.
4+
[PyEtaler](https://github.com/etaler/pyetaler) is the offical binding for Etaler. We try to keep the Python API as close to the C++ one as possible. So you can use the C++ document as the Python document. With that said, some functions are changed in the binding to make it more Pythonic.
55

66
```python
77
>>> from etaler import et

tests/common_tests.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
#include <Etaler/Algorithms/SDRClassifer.hpp>
1010

1111
#include <numeric>
12-
#include <execution>
1312

1413
using namespace et;
1514

@@ -320,6 +319,14 @@ TEST_CASE("Testing Tensor", "[Tensor]")
320319
CHECK((ones({4,4}) == t).any() == true);
321320
CHECK((ones({4,4}) == t).all() == false);
322321
}
322+
323+
SECTION("xtensor style views") {
324+
CHECK(view(t, 2).isSame(t.view({2})));
325+
326+
IndexList lst;
327+
lst.push_back(3);
328+
CHECK(dynamic_view(t, lst).isSame(t.view({3})));
329+
}
323330
}
324331

325332
SECTION("item") {
@@ -994,11 +1001,8 @@ TEST_CASE("Complex Tensor operations")
9941001
// Test summing along the first dimension. Making sure iterator and sum() works
9951002
// Tho you should always use the sum() function instead of accumulate or reduce
9961003
Tensor t = std::accumulate(a.begin(), a.end(), zeros({a.shape()[1]}));
997-
Tensor q = std::reduce(std::execution::par, a.begin(), a.end(), zeros({a.shape()[1]}));
9981004
Tensor a_sum = a.sum(0);
9991005
CHECK(t.isSame(a_sum));
1000-
CHECK(q.isSame(a_sum));
1001-
CHECK(t.isSame(q)); // Should be communicative
10021006
}
10031007

10041008
SECTION("generate") {

0 commit comments

Comments
 (0)