Skip to content

Commit 6d199f1

Browse files
shoumikhinfacebook-github-bot
authored andcommitted
Allow overriding metadata when creating a view.
Summary: . Differential Revision: D86070964
1 parent 5d04bd5 commit 6d199f1

File tree

4 files changed

+130
-10
lines changed

4 files changed

+130
-10
lines changed

extension/apple/ExecuTorch/Exported/ExecuTorch+Tensor.swift

Lines changed: 45 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,29 @@ public extension AnyTensor {
130130
/// The total number of elements in the tensor.
131131
var count: Int { __count }
132132

133+
/// Creates a new tensor that shares the underlying data storage with the
134+
/// given tensor, with metadata overrides. An empty array for
135+
/// a parameter signifies that it should be inherited or derived.
136+
///
137+
/// - Parameters:
138+
/// - tensor: The tensor instance to create a view of.
139+
/// - shape: An override for the tensor's shape.
140+
/// - dimensionOrder: An override for the tensor's dimension order.
141+
/// - strides: An override for the tensor's strides.
142+
convenience init(
143+
_ tensor: AnyTensor,
144+
shape: [Int] = [],
145+
dimensionOrder: [Int] = [],
146+
strides: [Int] = []
147+
) {
148+
self.init(
149+
__tensor: tensor,
150+
shape: shape.map(NSNumber.init),
151+
dimensionOrder: dimensionOrder.map(NSNumber.init),
152+
strides: strides.map(NSNumber.init)
153+
)
154+
}
155+
133156
/// Initializes a tensor without copying the provided data.
134157
///
135158
/// - Parameters:
@@ -234,8 +257,7 @@ public extension AnyTensor {
234257

235258
/// Attempts to convert this type-erased `AnyTensor` into a strongly-typed `Tensor<T>`.
236259
///
237-
/// - Returns: An `AnyTensor` if `self.dataType == T.dataType`,
238-
/// otherwise `nil` when the runtime dtype doesn’t match.
260+
/// - Returns: A `Tensor<T>` if the runtime data type matches, otherwise `nil`.
239261
func asTensor<T: Scalar>() -> Tensor<T>? {
240262
guard dataType == T.dataType else { return nil }
241263
return Tensor<T>(self)
@@ -586,11 +608,28 @@ public final class Tensor<T: Scalar>: Equatable {
586608
}
587609

588610
/// Creates a new tensor that shares the underlying data storage with the
589-
/// given tensor. This new tensor is a view and does not own the data.
611+
/// given tensor, with optional metadata overrides. An empty array for
612+
/// a parameter signifies that it should be inherited or derived.
590613
///
591-
/// - Parameter tensor: The tensor to create a view of.
592-
public convenience init(_ tensor: Tensor<T>) {
593-
self.init(AnyTensor(tensor.anyTensor))
614+
/// - Parameters:
615+
/// - tensor: The tensor to create a view of.
616+
/// - shape: An override for the tensor's shape.
617+
/// - dimensionOrder: An override for the tensor's dimension order.
618+
/// - strides: An override for the tensor's strides.
619+
public convenience init(
620+
_ tensor: Tensor<T>,
621+
shape: [Int] = [],
622+
dimensionOrder: [Int] = [],
623+
strides: [Int] = []
624+
) {
625+
self.init(
626+
AnyTensor(
627+
tensor.anyTensor,
628+
shape: shape,
629+
dimensionOrder: dimensionOrder,
630+
strides: strides
631+
)
632+
)
594633
}
595634

596635
/// Initializes a tensor without copying the provided data.
@@ -742,7 +781,6 @@ public final class Tensor<T: Scalar>: Equatable {
742781
}
743782

744783
/// Returns a copy of the tensor, converted to the specified scalar type.
745-
///
746784
/// - Parameter dataType: The target scalar type.
747785
/// - Returns: A new tensor with the same shape and metadata but converted elements.
748786
public func copy<U: Scalar>(to dataType: U.Type) -> Tensor<U> {

extension/apple/ExecuTorch/Exported/ExecuTorchTensor.h

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,35 @@ __attribute__((objc_subclassing_restricted))
150150
NS_DESIGNATED_INITIALIZER
151151
NS_SWIFT_UNAVAILABLE("");
152152

153+
/**
154+
* Creates a new tensor that shares the underlying data storage with the
155+
* given tensor, with metadata overrides. An empty array for
156+
* a parameter signifies that it should be inherited or derived.
157+
*
158+
* @param otherTensor The tensor instance to create a view of.
159+
* @param shape An override for the tensor's shape.
160+
* @param dimensionOrder An override for the tensor's dimension order.
161+
* @param strides An override for the tensor's strides.
162+
* @return A new ExecuTorchTensor instance that shares data with otherTensor.
163+
*/
164+
- (instancetype)initWithTensor:(ExecuTorchTensor *)otherTensor
165+
shape:(NSArray<NSNumber *> *)shape
166+
dimensionOrder:(NSArray<NSNumber *> *)dimensionOrder
167+
strides:(NSArray<NSNumber *> *)strides
168+
NS_REFINED_FOR_SWIFT;
169+
170+
/**
171+
* Creates a new tensor that shares the underlying data storage with the
172+
* given tensor, with an overridden shape.
173+
*
174+
* @param otherTensor The tensor instance to create a view of.
175+
* @param shape An override for the tensor's shape.
176+
* @return A new ExecuTorchTensor instance that shares data with otherTensor.
177+
*/
178+
- (instancetype)initWithTensor:(ExecuTorchTensor *)otherTensor
179+
shape:(NSArray<NSNumber *> *)shape
180+
NS_SWIFT_UNAVAILABLE("");
181+
153182
/**
154183
* Creates a new tensor that shares the underlying data storage with the
155184
* given tensor. This new tensor is a view and does not own the data.
@@ -158,7 +187,7 @@ __attribute__((objc_subclassing_restricted))
158187
* @return A new ExecuTorchTensor instance that shares data with otherTensor.
159188
*/
160189
- (instancetype)initWithTensor:(ExecuTorchTensor *)otherTensor
161-
NS_SWIFT_NAME(init(_:));
190+
NS_SWIFT_UNAVAILABLE("");
162191

163192
/**
164193
* Creates a deep copy of the tensor.

extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,10 +126,16 @@ - (instancetype)initWithNativeInstance:(void *)nativeInstance {
126126
return self;
127127
}
128128

129-
- (instancetype)initWithTensor:(ExecuTorchTensor *)otherTensor {
129+
- (instancetype)initWithTensor:(ExecuTorchTensor *)otherTensor
130+
shape:(NSArray<NSNumber *> *)shape
131+
dimensionOrder:(NSArray<NSNumber *> *)dimensionOrder
132+
strides:(NSArray<NSNumber *> *)strides {
130133
ET_CHECK(otherTensor);
131134
auto tensor = make_tensor_ptr(
132-
*reinterpret_cast<TensorPtr *>(otherTensor.nativeInstance)
135+
*reinterpret_cast<TensorPtr *>(otherTensor.nativeInstance),
136+
utils::toVector<SizesType>(shape),
137+
utils::toVector<DimOrderType>(dimensionOrder),
138+
utils::toVector<StridesType>(strides)
133139
);
134140
self = [self initWithNativeInstance:&tensor];
135141
if (self) {
@@ -138,6 +144,21 @@ - (instancetype)initWithTensor:(ExecuTorchTensor *)otherTensor {
138144
return self;
139145
}
140146

147+
- (instancetype)initWithTensor:(ExecuTorchTensor *)otherTensor
148+
shape:(NSArray<NSNumber *> *)shape {
149+
return [self initWithTensor:otherTensor
150+
shape:shape
151+
dimensionOrder:@[]
152+
strides:@[]];
153+
}
154+
155+
- (instancetype)initWithTensor:(ExecuTorchTensor *)otherTensor {
156+
return [self initWithTensor:otherTensor
157+
shape:@[]
158+
dimensionOrder:@[]
159+
strides:@[]];
160+
}
161+
141162
- (instancetype)copy {
142163
return [self copyWithZone:nil];
143164
}

extension/apple/ExecuTorch/__tests__/TensorTest.swift

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,38 @@ class TensorTest: XCTestCase {
168168
XCTAssertEqual(tensor2.count, tensor1.count)
169169
}
170170

171+
func testInitWithTensorDerivesStridesAndSharesStorage() {
172+
var scalars: [Int32] = [1, 2, 3, 4, 5, 6]
173+
let tensor1 = scalars.withUnsafeMutableBytes {
174+
Tensor<Int32>(bytesNoCopy: $0.baseAddress!, shape: [2, 3])
175+
}
176+
let tensor2 = Tensor(tensor1, shape: [3, 2])
177+
178+
XCTAssertEqual(tensor1.withUnsafeBytes { $0.baseAddress }, tensor2.withUnsafeBytes { $0.baseAddress })
179+
XCTAssertEqual(tensor2.shape, [3, 2])
180+
XCTAssertEqual(tensor2.strides, [2, 1])
181+
XCTAssertEqual(tensor2.dimensionOrder, [0, 1])
182+
183+
scalars[0] = 99
184+
XCTAssertEqual(tensor2.withUnsafeBytes { $0[0] }, 99)
185+
}
186+
187+
func testInitWithTensorExplicitOverridesAppliesMetadata() {
188+
var scalars: [Float] = [1, 2, 3, 4]
189+
let tensor1 = scalars.withUnsafeMutableBytes {
190+
Tensor<Float>(bytesNoCopy: $0.baseAddress!, shape: [2, 2])
191+
}
192+
let tensor2 = Tensor(tensor1, shape: [2, 2], dimensionOrder: [1, 0], strides: [1, 2])
193+
194+
XCTAssertEqual(tensor1.withUnsafeBytes { $0.baseAddress }, tensor2.withUnsafeBytes { $0.baseAddress })
195+
XCTAssertEqual(tensor2.shape, [2, 2])
196+
XCTAssertEqual(tensor2.dimensionOrder, [1, 0])
197+
XCTAssertEqual(tensor2.strides, [1, 2])
198+
199+
scalars[3] = 42
200+
XCTAssertEqual(tensor2.withUnsafeBytes { $0[3] }, 42)
201+
}
202+
171203
func testCopy() {
172204
var data: [Double] = [10.0, 20.0, 30.0, 40.0]
173205
let tensor1 = data.withUnsafeMutableBytes {

0 commit comments

Comments
 (0)