Skip to content

Commit 7030b33

Browse files
authored
Tensor constructor to create with a raw pointer by copying the data.
Differential Revision: D71904351 Pull Request resolved: #9674
1 parent 1ea101e commit 7030b33

File tree

2 files changed

+49
-0
lines changed

2 files changed

+49
-0
lines changed

extension/apple/ExecuTorch/Exported/ExecuTorchTensor.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,4 +222,28 @@ __attribute__((deprecated("This API is experimental.")))
222222

223223
@end
224224

225+
#pragma mark - Bytes Category
226+
227+
@interface ExecuTorchTensor (Bytes)
228+
229+
/**
230+
* Initializes a tensor by copying bytes from the provided pointer.
231+
*
232+
* @param pointer A pointer to the source data buffer.
233+
* @param shape An NSArray of NSNumber objects representing the tensor's shape.
234+
* @param strides An NSArray of NSNumber objects representing the tensor's strides.
235+
* @param dimensionOrder An NSArray of NSNumber objects indicating the order of dimensions.
236+
* @param dataType An ExecuTorchDataType value specifying the element type.
237+
* @param shapeDynamism An ExecuTorchShapeDynamism value indicating the shape dynamism.
238+
* @return An initialized ExecuTorchTensor instance with its own copy of the data.
239+
*/
240+
- (instancetype)initWithBytes:(const void *)pointer
241+
shape:(NSArray<NSNumber *> *)shape
242+
strides:(NSArray<NSNumber *> *)strides
243+
dimensionOrder:(NSArray<NSNumber *> *)dimensionOrder
244+
dataType:(ExecuTorchDataType)dataType
245+
shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism;
246+
247+
@end
248+
225249
NS_ASSUME_NONNULL_END

extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,3 +140,28 @@ - (instancetype)initWithBytesNoCopy:(void *)pointer
140140
}
141141

142142
@end
143+
144+
@implementation ExecuTorchTensor (Bytes)
145+
146+
- (instancetype)initWithBytes:(const void *)pointer
147+
shape:(NSArray<NSNumber *> *)shape
148+
strides:(NSArray<NSNumber *> *)strides
149+
dimensionOrder:(NSArray<NSNumber *> *)dimensionOrder
150+
dataType:(ExecuTorchDataType)dataType
151+
shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism {
152+
ET_CHECK(pointer);
153+
const auto size = ExecuTorchElementCountOfShape(shape) * ExecuTorchSizeOfDataType(dataType);
154+
std::vector<uint8_t> data(static_cast<const uint8_t *>(pointer),
155+
static_cast<const uint8_t *>(pointer) + size);
156+
auto tensor = make_tensor_ptr(
157+
utils::toVector<SizesType>(shape),
158+
std::move(data),
159+
utils::toVector<DimOrderType>(dimensionOrder),
160+
utils::toVector<StridesType>(strides),
161+
static_cast<ScalarType>(dataType),
162+
static_cast<TensorShapeDynamism>(shapeDynamism)
163+
);
164+
return [self initWithNativeInstance:&tensor];
165+
}
166+
167+
@end

0 commit comments

Comments
 (0)