@@ -190,8 +190,10 @@ static ::imex::ndarray::EWUnyOpId sharpy(const EWUnyOpId uop) {
190190 return ::imex::ndarray::LOGICAL_NOT;
191191 case __NEG__:
192192 case NEGATIVE:
193+ return ::imex::ndarray::NEGATIVE;
193194 case __POS__:
194195 case POSITIVE:
196+ return ::imex::ndarray::POSITIVE;
195197 default :
196198 throw std::runtime_error (" Unknown/invalid elementwise unary operation" );
197199 }
@@ -213,24 +215,31 @@ struct DeferredEWUnyOp : public Deferred {
213215 auto aTyp = av.getType ().cast <::imex::ndarray::NDArrayType>();
214216 auto outTyp = aTyp.cloneWith (shape (), aTyp.getElementType ());
215217
218+ auto ndOpId = sharpy (_op);
216219 auto uop = builder.create <::imex::ndarray::EWUnyOp>(
217- loc, outTyp, builder.getI32IntegerAttr (sharpy (_op)), av);
220+ loc, outTyp, builder.getI32IntegerAttr (ndOpId), av);
221+ // positive op will be eliminated so it is equivalent to a view
222+ auto view = ndOpId == ::imex::ndarray::POSITIVE;
218223
219224 dm.addVal (
220225 this ->guid (), uop,
221- [this ](uint64_t rank, void *l_allocated, void *l_aligned,
222- intptr_t l_offset, const intptr_t *l_sizes,
223- const intptr_t *l_strides, void *o_allocated, void *o_aligned,
224- intptr_t o_offset, const intptr_t *o_sizes,
225- const intptr_t *o_strides, void *r_allocated, void *r_aligned,
226- intptr_t r_offset, const intptr_t *r_sizes,
227- const intptr_t *r_strides, uint64_t *lo_allocated,
228- uint64_t *lo_aligned) {
229- this ->set_value (std::move (mk_tnsr (
230- this ->guid (), _dtype, this ->shape (), this ->device (), this ->team (),
231- l_allocated, l_aligned, l_offset, l_sizes, l_strides, o_allocated,
232- o_aligned, o_offset, o_sizes, o_strides, r_allocated, r_aligned,
233- r_offset, r_sizes, r_strides, lo_allocated, lo_aligned)));
226+ [this , view](uint64_t rank, void *l_allocated, void *l_aligned,
227+ intptr_t l_offset, const intptr_t *l_sizes,
228+ const intptr_t *l_strides, void *o_allocated,
229+ void *o_aligned, intptr_t o_offset,
230+ const intptr_t *o_sizes, const intptr_t *o_strides,
231+ void *r_allocated, void *r_aligned, intptr_t r_offset,
232+ const intptr_t *r_sizes, const intptr_t *r_strides,
233+ uint64_t *lo_allocated, uint64_t *lo_aligned) {
234+ auto t = mk_tnsr (this ->guid (), _dtype, this ->shape (), this ->device (),
235+ this ->team (), l_allocated, l_aligned, l_offset,
236+ l_sizes, l_strides, o_allocated, o_aligned, o_offset,
237+ o_sizes, o_strides, r_allocated, r_aligned, r_offset,
238+ r_sizes, r_strides, lo_allocated, lo_aligned);
239+ if (view && Registry::has (_a)) {
240+ t->set_base (Registry::get (_a).get ());
241+ }
242+ this ->set_value (std::move (t));
234243 });
235244 return false ;
236245 }
0 commit comments