Skip to content

Commit 958af35

Browse files
committed
support () for indexing
1 parent aa75d76 commit 958af35

File tree

1 file changed

+19
-19
lines changed

1 file changed

+19
-19
lines changed

Etaler/Core/Tensor.hpp

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -211,15 +211,15 @@ struct ETALER_EXPORT Tensor
211211
Tensor log() const { return backend()->log(pimpl()); }
212212
Tensor logical_not() const { return backend()->logical_not(pimpl()); }
213213

214-
Tensor add(const Tensor& other) const { auto [a, b] = brodcast(other); return backend()->add(a(), b()); }
215-
Tensor subtract(const Tensor& other) const { auto [a, b] = brodcast(other); return backend()->subtract(a(), b()); }
216-
Tensor mul(const Tensor& other) const { auto [a, b] = brodcast(other); return backend()->mul(a(), b()); }
217-
Tensor div(const Tensor& other) const { auto [a, b] = brodcast(other); return backend()->div(a(), b()); }
218-
Tensor equal(const Tensor& other) const { auto [a, b] = brodcast(other); return backend()->equal(a(), b()); }
219-
Tensor greater(const Tensor& other) const { auto [a, b] = brodcast(other); return backend()->greater(a(), b()); }
220-
Tensor lesser(const Tensor& other) const { auto [a, b] = brodcast(other); return backend()->lesser(a(), b()); }
221-
Tensor logical_and(const Tensor& other) const { auto [a, b] = brodcast(other); return backend()->logical_and(a(), b()); }
222-
Tensor logical_or(const Tensor& other) const { auto [a, b] = brodcast(other); return backend()->logical_or(a(), b()); }
214+
Tensor add(const Tensor& other) const { auto [a, b] = brodcast(other); return backend()->add(a.pimpl(), b.pimpl()); }
215+
Tensor subtract(const Tensor& other) const { auto [a, b] = brodcast(other); return backend()->subtract(a.pimpl(), b.pimpl()); }
216+
Tensor mul(const Tensor& other) const { auto [a, b] = brodcast(other); return backend()->mul(a.pimpl(), b.pimpl()); }
217+
Tensor div(const Tensor& other) const { auto [a, b] = brodcast(other); return backend()->div(a.pimpl(), b.pimpl()); }
218+
Tensor equal(const Tensor& other) const { auto [a, b] = brodcast(other); return backend()->equal(a.pimpl(), b.pimpl()); }
219+
Tensor greater(const Tensor& other) const { auto [a, b] = brodcast(other); return backend()->greater(a.pimpl(), b.pimpl()); }
220+
Tensor lesser(const Tensor& other) const { auto [a, b] = brodcast(other); return backend()->lesser(a.pimpl(), b.pimpl()); }
221+
Tensor logical_and(const Tensor& other) const { auto [a, b] = brodcast(other); return backend()->logical_and(a.pimpl(), b.pimpl()); }
222+
Tensor logical_or(const Tensor& other) const { auto [a, b] = brodcast(other); return backend()->logical_or(a.pimpl(), b.pimpl()); }
223223

224224
inline bool any() const { return cast(DType::Bool).sum(std::nullopt, DType::Bool).item<uint8_t>(); }
225225
inline bool all() const { return cast(DType::Bool).sum(std::nullopt).item<int32_t>() == int32_t(size()); }
@@ -252,14 +252,14 @@ struct ETALER_EXPORT Tensor
252252

253253
//Subscription operator
254254
Tensor operator [] (const IndexList& r) { return view(r); }
255+
template <typename ... Args>
256+
Tensor operator () (Args ... args) { return view({args ...}); }
255257

256258
Tensor sum(std::optional<intmax_t> dim=std::nullopt, DType dtype=DType::Unknown) const;
257259
Tensor abs() const { return backend()->abs(pimpl()); }
258260
bool isSame (const Tensor& other) const;
259261

260262
//Utils
261-
TensorImpl* operator () () {return pimpl();}
262-
const TensorImpl* operator () () const {return pimpl();}
263263

264264
using iterator = TensorIterator<Tensor>;
265265
using const_iterator = TensorIterator<const Tensor>;
@@ -332,18 +332,18 @@ inline Tensor cellActivity(const Tensor& x, const Tensor& connections, const Ten
332332
return x;
333333
return x.cast(DType::Bool);
334334
}();
335-
return x.backend()->cellActivity(input(), connections(), permeances(), connected_permeance, active_threshold, has_unconnected_synapse);
335+
return x.backend()->cellActivity(input.pimpl(), connections.pimpl(), permeances.pimpl(), connected_permeance, active_threshold, has_unconnected_synapse);
336336
}
337337

338338
inline void learnCorrilation(const Tensor& x, const Tensor& learn, const Tensor& connection
339339
, Tensor& permeances, float perm_inc, float perm_dec, bool has_unconnected_synapse=true)
340340
{
341-
x.backend()->learnCorrilation(x(), learn(), connection(), permeances(), perm_inc, perm_dec, has_unconnected_synapse);
341+
x.backend()->learnCorrilation(x.pimpl(), learn.pimpl(), connection.pimpl(), permeances.pimpl(), perm_inc, perm_dec, has_unconnected_synapse);
342342
}
343343

344344
inline Tensor globalInhibition(const Tensor& x, float fraction)
345345
{
346-
return x.backend()->globalInhibition(x(), fraction);
346+
return x.backend()->globalInhibition(x.pimpl(), fraction);
347347
}
348348

349349
Tensor inline cast(const Tensor& x, DType dtype)
@@ -358,27 +358,27 @@ inline Tensor copy(const Tensor& x)
358358

359359
inline void sortSynapse(Tensor& connection, Tensor& permeances)
360360
{
361-
connection.backend()->sortSynapse(connection(), permeances());
361+
connection.backend()->sortSynapse(connection.pimpl(), permeances.pimpl());
362362
}
363363

364364
inline Tensor burst(const Tensor& x, const Tensor& s)
365365
{
366-
return x.backend()->burst(x(), s());
366+
return x.backend()->burst(x.pimpl(), s.pimpl());
367367
}
368368

369369
inline Tensor reverseBurst(const Tensor& x)
370370
{
371-
return x.backend()->reverseBurst(x());
371+
return x.backend()->reverseBurst(x.pimpl());
372372
}
373373

374374
inline void growSynapses(const Tensor& x, const Tensor& y, Tensor& connections, Tensor& permeances, float init_perm)
375375
{
376-
x.backend()->growSynapses(x(), y(), connections(), permeances(), init_perm);
376+
x.backend()->growSynapses(x.pimpl(), y.pimpl(), connections.pimpl(), permeances.pimpl(), init_perm);
377377
}
378378

379379
inline void decaySynapses(Tensor& connections, Tensor& permeances, float threshold)
380380
{
381-
connections.backend()->decaySynapses(connections(), permeances(), threshold);
381+
connections.backend()->decaySynapses(connections.pimpl(), permeances.pimpl(), threshold);
382382
}
383383

384384
inline void assign(Tensor& x, const Tensor& y)

0 commit comments

Comments
 (0)