33*/
44
55#include " sharpy/Creator.hpp"
6- #include " sharpy/NDArray.hpp"
76#include " sharpy/Deferred.hpp"
87#include " sharpy/Factory.hpp"
8+ #include " sharpy/NDArray.hpp"
99#include " sharpy/Transceiver.hpp"
1010#include " sharpy/TypeDispatch.hpp"
1111#include " sharpy/jit/mlir.hpp"
@@ -82,12 +82,11 @@ struct DeferredFull : public Deferred {
8282 const intptr_t *r_strides, uint64_t *lo_allocated,
8383 uint64_t *lo_aligned) {
8484 assert (rank == this ->rank ());
85- this ->set_value (std::move (
86- mk_tnsr (reinterpret_cast <Transceiver *>(this ->team ()), _dtype,
87- this ->shape (), l_allocated, l_aligned, l_offset, l_sizes,
88- l_strides, o_allocated, o_aligned, o_offset, o_sizes,
89- o_strides, r_allocated, r_aligned, r_offset, r_sizes,
90- r_strides, lo_allocated, lo_aligned)));
85+ this ->set_value (std::move (mk_tnsr (
86+ this ->guid (), _dtype, this ->shape (), this ->device (), this ->team (),
87+ l_allocated, l_aligned, l_offset, l_sizes, l_strides, o_allocated,
88+ o_aligned, o_offset, o_sizes, o_strides, r_allocated, r_aligned,
89+ r_offset, r_sizes, r_strides, lo_allocated, lo_aligned)));
9190 });
9291 return false ;
9392 }
@@ -102,8 +101,8 @@ struct DeferredFull : public Deferred {
102101};
103102
104103FutureArray *Creator::full (const shape_type &shape, const py::object &val,
105- DTypeId dtype, const std::string &device,
106- uint64_t team) {
104+ DTypeId dtype, const std::string &device,
105+ uint64_t team) {
107106 auto v = mk_scalar (val, dtype);
108107 return new FutureArray (
109108 defer<DeferredFull>(shape, v, dtype, device, mkTeam (team)));
@@ -132,26 +131,26 @@ struct DeferredArange : public Deferred {
132131 auto dtyp = jit::getPTDType (dtype ());
133132 auto envs = jit::mkEnvs (builder, rank (), _device, team ());
134133
135- dm.addVal (this -> guid (),
136- builder. create <::imex::ndarray::LinSpaceOp>(loc, start, stop, num ,
137- false , dtyp, envs) ,
138- [ this ]( uint64_t rank, void *l_allocated, void *l_aligned ,
139- intptr_t l_offset, const intptr_t *l_sizes ,
140- const intptr_t *l_strides, void *o_allocated ,
141- void *o_aligned, intptr_t o_offset ,
142- const intptr_t *o_sizes , const intptr_t *o_strides ,
143- void *r_allocated, void *r_aligned, intptr_t r_offset ,
144- const intptr_t *r_sizes , const intptr_t *r_strides ,
145- uint64_t *lo_allocated , uint64_t *lo_aligned) {
146- assert (rank == 1 );
147- assert (o_strides[ 0 ] == 1 );
148- this -> set_value ( std::move ( mk_tnsr (
149- reinterpret_cast <Transceiver *>( this ->team ()), _dtype,
150- this ->shape (), l_allocated, l_aligned, l_offset, l_sizes ,
151- l_strides, o_allocated, o_aligned, o_offset, o_sizes ,
152- o_strides, r_allocated, r_aligned, r_offset, r_sizes ,
153- r_strides, lo_allocated, lo_aligned)));
154- });
134+ dm.addVal (
135+ this -> guid () ,
136+ builder. create <::imex::ndarray::LinSpaceOp>(loc, start, stop, num ,
137+ false , dtyp, envs) ,
138+ [ this ]( uint64_t rank, void *l_allocated, void *l_aligned ,
139+ intptr_t l_offset, const intptr_t *l_sizes ,
140+ const intptr_t *l_strides, void *o_allocated, void *o_aligned ,
141+ intptr_t o_offset , const intptr_t *o_sizes ,
142+ const intptr_t *o_strides, void *r_allocated, void *r_aligned,
143+ intptr_t r_offset , const intptr_t *r_sizes ,
144+ const intptr_t *r_strides , uint64_t *lo_allocated,
145+ uint64_t *lo_aligned) {
146+ assert (rank == 1 );
147+ assert (o_strides[ 0 ] == 1 );
148+ this ->set_value ( std::move ( mk_tnsr (
149+ this ->guid (), _dtype, this -> shape (), this -> device (), this -> team () ,
150+ l_allocated, l_aligned, l_offset, l_sizes, l_strides, o_allocated ,
151+ o_aligned, o_offset, o_sizes, o_strides, r_allocated, r_aligned,
152+ r_offset, r_sizes, r_strides, lo_allocated, lo_aligned)));
153+ });
155154 return false ;
156155 }
157156
@@ -165,8 +164,8 @@ struct DeferredArange : public Deferred {
165164};
166165
167166FutureArray *Creator::arange (uint64_t start, uint64_t end, uint64_t step,
168- DTypeId dtype, const std::string &device,
169- uint64_t team) {
167+ DTypeId dtype, const std::string &device,
168+ uint64_t team) {
170169 return new FutureArray (
171170 defer<DeferredArange>(start, end, step, dtype, device, mkTeam (team)));
172171}
@@ -193,26 +192,26 @@ struct DeferredLinspace : public Deferred {
193192 auto dtyp = jit::getPTDType (dtype ());
194193 auto envs = jit::mkEnvs (builder, rank (), _device, team ());
195194
196- dm.addVal (this -> guid (),
197- builder. create <::imex::ndarray::LinSpaceOp>(
198- loc, start, stop, num, _endpoint, dtyp, envs) ,
199- [ this ]( uint64_t rank, void *l_allocated, void *l_aligned ,
200- intptr_t l_offset, const intptr_t *l_sizes ,
201- const intptr_t *l_strides, void *o_allocated ,
202- void *o_aligned, intptr_t o_offset ,
203- const intptr_t *o_sizes , const intptr_t *o_strides ,
204- void *r_allocated, void *r_aligned, intptr_t r_offset ,
205- const intptr_t *r_sizes , const intptr_t *r_strides ,
206- uint64_t *lo_allocated , uint64_t *lo_aligned) {
207- assert (rank == 1 );
208- assert (l_strides[ 0 ] == 1 );
209- this -> set_value ( std::move ( mk_tnsr (
210- reinterpret_cast <Transceiver *>( this ->team ()), _dtype,
211- this ->shape (), l_allocated, l_aligned, l_offset, l_sizes ,
212- l_strides, o_allocated, o_aligned, o_offset, o_sizes ,
213- o_strides, r_allocated, r_aligned, r_offset, r_sizes ,
214- r_strides, lo_allocated, lo_aligned)));
215- });
195+ dm.addVal (
196+ this -> guid (),
197+ builder. create <::imex::ndarray::LinSpaceOp>( loc, start, stop, num,
198+ _endpoint, dtyp, envs) ,
199+ [ this ]( uint64_t rank, void *l_allocated, void *l_aligned ,
200+ intptr_t l_offset, const intptr_t *l_sizes ,
201+ const intptr_t *l_strides, void *o_allocated, void *o_aligned ,
202+ intptr_t o_offset , const intptr_t *o_sizes ,
203+ const intptr_t *o_strides, void *r_allocated, void *r_aligned,
204+ intptr_t r_offset , const intptr_t *r_sizes ,
205+ const intptr_t *r_strides , uint64_t *lo_allocated,
206+ uint64_t *lo_aligned) {
207+ assert (rank == 1 );
208+ assert (l_strides[ 0 ] == 1 );
209+ this ->set_value ( std::move ( mk_tnsr (
210+ this ->guid (), _dtype, this -> shape (), this -> device (), this -> team () ,
211+ l_allocated, l_aligned, l_offset, l_sizes, l_strides, o_allocated ,
212+ o_aligned, o_offset, o_sizes, o_strides, r_allocated, r_aligned,
213+ r_offset, r_sizes, r_strides, lo_allocated, lo_aligned)));
214+ });
216215 return false ;
217216 }
218217
@@ -227,10 +226,10 @@ struct DeferredLinspace : public Deferred {
227226};
228227
229228FutureArray *Creator::linspace (double start, double end, uint64_t num,
230- bool endpoint, DTypeId dtype,
231- const std::string &device, uint64_t team) {
232- return new FutureArray (defer<DeferredLinspace>(start, end, num, endpoint, dtype,
233- device, mkTeam (team)));
229+ bool endpoint, DTypeId dtype,
230+ const std::string &device, uint64_t team) {
231+ return new FutureArray (defer<DeferredLinspace>(start, end, num, endpoint,
232+ dtype, device, mkTeam (team)));
234233}
235234
236235// ***************************************************************************
@@ -239,8 +238,9 @@ extern DTypeId DEFAULT_FLOAT;
239238extern DTypeId DEFAULT_INT;
240239
241240std::pair<FutureArray *, bool > Creator::mk_future (const py::object &b,
242- const std::string &device,
243- uint64_t team, DTypeId dtype) {
241+ const std::string &device,
242+ uint64_t team,
243+ DTypeId dtype) {
244244 if (py::isinstance<FutureArray>(b)) {
245245 return {b.cast <FutureArray *>(), false };
246246 } else if (py::isinstance<py::float_>(b) || py::isinstance<py::int_>(b)) {
0 commit comments