Skip to content

Commit 480a3b5

Browse files
authored
Simplify Internal Tensor Creation (#159)
* add private constructor taking pytorch tensor so can simplify tensor instantiation within tensor functions
1 parent c24f85d commit 480a3b5

File tree

2 files changed

+57
-130
lines changed

2 files changed

+57
-130
lines changed

nuTens/tensors/tensor.hpp

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -435,10 +435,7 @@ class Tensor
435435

436436
NT_PROFILE();
437437

438-
Tensor ret;
439-
ret.setTensor(tensor);
440-
441-
return ret;
438+
return Tensor(tensor);
442439
}
443440

444441
protected:
@@ -496,8 +493,21 @@ class Tensor
496493
return indicesVec;
497494
}
498495

496+
private:
497+
498+
/// Construct a nuTens tensor directly from a pytorch tensor
499+
Tensor(const torch::Tensor &tensor)
500+
:
501+
_tensor(tensor),
502+
_dType(dtypes::invScalarTypeMap(tensor.scalar_type())),
503+
_device(dtypes::invDeviceTypeMap(tensor.device().type()))
504+
{
505+
NT_PROFILE();
506+
}
507+
499508
protected:
500509
torch::Tensor _tensor;
510+
501511
#endif
502512
};
503513

@@ -528,6 +538,8 @@ template <typename Tdtype, int TnDims, dtypes::deviceType Tdevice> class Accesse
528538
AccessedTensor(const torch::Tensor &tensor)
529539
: _packedAccessor(tensor.packed_accessor32<Tdtype, TnDims>()), _accessor(tensor.accessor<Tdtype, TnDims>())
530540
{
541+
NT_PROFILE();
542+
531543
setTensor(tensor);
532544
};
533545

0 commit comments

Comments
 (0)