@@ -211,15 +211,15 @@ struct ETALER_EXPORT Tensor
211211 Tensor log () const { return backend ()->log (pimpl ()); }
212212 Tensor logical_not () const { return backend ()->logical_not (pimpl ()); }
213213
214- Tensor add (const Tensor& other) const { auto [a, b] = brodcast (other); return backend ()->add (a (), b ()); }
215- Tensor subtract (const Tensor& other) const { auto [a, b] = brodcast (other); return backend ()->subtract (a (), b ()); }
216- Tensor mul (const Tensor& other) const { auto [a, b] = brodcast (other); return backend ()->mul (a (), b ()); }
217- Tensor div (const Tensor& other) const { auto [a, b] = brodcast (other); return backend ()->div (a (), b ()); }
218- Tensor equal (const Tensor& other) const { auto [a, b] = brodcast (other); return backend ()->equal (a (), b ()); }
219- Tensor greater (const Tensor& other) const { auto [a, b] = brodcast (other); return backend ()->greater (a (), b ()); }
220- Tensor lesser (const Tensor& other) const { auto [a, b] = brodcast (other); return backend ()->lesser (a (), b ()); }
221- Tensor logical_and (const Tensor& other) const { auto [a, b] = brodcast (other); return backend ()->logical_and (a (), b ()); }
222- Tensor logical_or (const Tensor& other) const { auto [a, b] = brodcast (other); return backend ()->logical_or (a (), b ()); }
214+ Tensor add (const Tensor& other) const { auto [a, b] = brodcast (other); return backend ()->add (a. pimpl (), b. pimpl ()); }
215+ Tensor subtract (const Tensor& other) const { auto [a, b] = brodcast (other); return backend ()->subtract (a. pimpl (), b. pimpl ()); }
216+ Tensor mul (const Tensor& other) const { auto [a, b] = brodcast (other); return backend ()->mul (a. pimpl (), b. pimpl ()); }
217+ Tensor div (const Tensor& other) const { auto [a, b] = brodcast (other); return backend ()->div (a. pimpl (), b. pimpl ()); }
218+ Tensor equal (const Tensor& other) const { auto [a, b] = brodcast (other); return backend ()->equal (a. pimpl (), b. pimpl ()); }
219+ Tensor greater (const Tensor& other) const { auto [a, b] = brodcast (other); return backend ()->greater (a. pimpl (), b. pimpl ()); }
220+ Tensor lesser (const Tensor& other) const { auto [a, b] = brodcast (other); return backend ()->lesser (a. pimpl (), b. pimpl ()); }
221+ Tensor logical_and (const Tensor& other) const { auto [a, b] = brodcast (other); return backend ()->logical_and (a. pimpl (), b. pimpl ()); }
222+ Tensor logical_or (const Tensor& other) const { auto [a, b] = brodcast (other); return backend ()->logical_or (a. pimpl (), b. pimpl ()); }
223223
224224 inline bool any () const { return cast (DType::Bool).sum (std::nullopt , DType::Bool).item <uint8_t >(); }
225225 inline bool all () const { return cast (DType::Bool).sum (std::nullopt ).item <int32_t >() == int32_t (size ()); }
@@ -252,14 +252,14 @@ struct ETALER_EXPORT Tensor
252252
253253 // Subscription operator
254254 Tensor operator [] (const IndexList& r) { return view (r); }
255+ template <typename ... Args>
256+ Tensor operator () (Args ... args) { return view ({args ...}); }
255257
256258 Tensor sum (std::optional<intmax_t > dim=std::nullopt , DType dtype=DType::Unknown) const ;
257259 Tensor abs () const { return backend ()->abs (pimpl ()); }
258260 bool isSame (const Tensor& other) const ;
259261
260262 // Utils
261- TensorImpl* operator () () {return pimpl ();}
262- const TensorImpl* operator () () const {return pimpl ();}
263263
264264 using iterator = TensorIterator<Tensor>;
265265 using const_iterator = TensorIterator<const Tensor>;
@@ -332,18 +332,18 @@ inline Tensor cellActivity(const Tensor& x, const Tensor& connections, const Ten
332332 return x;
333333 return x.cast (DType::Bool);
334334 }();
335- return x.backend ()->cellActivity (input (), connections (), permeances (), connected_permeance, active_threshold, has_unconnected_synapse);
335+ return x.backend ()->cellActivity (input. pimpl (), connections. pimpl (), permeances. pimpl (), connected_permeance, active_threshold, has_unconnected_synapse);
336336}
337337
338338inline void learnCorrilation (const Tensor& x, const Tensor& learn, const Tensor& connection
339339 , Tensor& permeances, float perm_inc, float perm_dec, bool has_unconnected_synapse=true )
340340{
341- x.backend ()->learnCorrilation (x (), learn (), connection (), permeances (), perm_inc, perm_dec, has_unconnected_synapse);
341+ x.backend ()->learnCorrilation (x. pimpl (), learn. pimpl (), connection. pimpl (), permeances. pimpl (), perm_inc, perm_dec, has_unconnected_synapse);
342342}
343343
344344inline Tensor globalInhibition (const Tensor& x, float fraction)
345345{
346- return x.backend ()->globalInhibition (x (), fraction);
346+ return x.backend ()->globalInhibition (x. pimpl (), fraction);
347347}
348348
349349Tensor inline cast (const Tensor& x, DType dtype)
@@ -358,27 +358,27 @@ inline Tensor copy(const Tensor& x)
358358
359359inline void sortSynapse (Tensor& connection, Tensor& permeances)
360360{
361- connection.backend ()->sortSynapse (connection (), permeances ());
361+ connection.backend ()->sortSynapse (connection. pimpl (), permeances. pimpl ());
362362}
363363
364364inline Tensor burst (const Tensor& x, const Tensor& s)
365365{
366- return x.backend ()->burst (x (), s ());
366+ return x.backend ()->burst (x. pimpl (), s. pimpl ());
367367}
368368
369369inline Tensor reverseBurst (const Tensor& x)
370370{
371- return x.backend ()->reverseBurst (x ());
371+ return x.backend ()->reverseBurst (x. pimpl ());
372372}
373373
374374inline void growSynapses (const Tensor& x, const Tensor& y, Tensor& connections, Tensor& permeances, float init_perm)
375375{
376- x.backend ()->growSynapses (x (), y (), connections (), permeances (), init_perm);
376+ x.backend ()->growSynapses (x. pimpl (), y. pimpl (), connections. pimpl (), permeances. pimpl (), init_perm);
377377}
378378
379379inline void decaySynapses (Tensor& connections, Tensor& permeances, float threshold)
380380{
381- connections.backend ()->decaySynapses (connections (), permeances (), threshold);
381+ connections.backend ()->decaySynapses (connections. pimpl (), permeances. pimpl (), threshold);
382382}
383383
384384inline void assign (Tensor& x, const Tensor& y)
@@ -420,6 +420,10 @@ inline Tensor logical_or(const Tensor& x1, const Tensor& x2) { return x1.logical
420420inline bool all (const Tensor& t) { return t.all (); }
421421inline bool any (const Tensor& t) { return t.any (); }
422422
423+ template <typename ... Args>
424+ inline Tensor view (const Tensor& t, Args... args) { return t.view ({args...}); }
425+ inline Tensor dynamic_view (const Tensor& t, const IndexList& indices) { return t.view (indices); }
426+
423427inline Tensor zeros_like (const Tensor& x) { return zeros (x.shape (), x.dtype (), x.backend ()); }
424428inline Tensor ones_like (const Tensor& x) { return ones (x.shape (), x.dtype (), x.backend ()); }
425429}
0 commit comments