Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 45 additions & 6 deletions extension/apple/ExecuTorch/Exported/ExecuTorch+Tensor.swift
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,29 @@ public extension AnyTensor {
/// The total number of elements in the tensor.
var count: Int { __count }

/// Creates a new tensor that shares the underlying data storage with the
/// given tensor, with metadata overrides. An empty array for
/// a parameter signifies that it should be inherited or derived.
///
/// - Parameters:
/// - tensor: The tensor instance to create a view of.
/// - shape: An override for the tensor's shape.
/// - dimensionOrder: An override for the tensor's dimension order.
/// - strides: An override for the tensor's strides.
convenience init(
_ tensor: AnyTensor,
shape: [Int] = [],
dimensionOrder: [Int] = [],
strides: [Int] = []
) {
self.init(
__tensor: tensor,
shape: shape.map(NSNumber.init),
dimensionOrder: dimensionOrder.map(NSNumber.init),
strides: strides.map(NSNumber.init)
)
}

/// Initializes a tensor without copying the provided data.
///
/// - Parameters:
Expand Down Expand Up @@ -234,8 +257,7 @@ public extension AnyTensor {

/// Attempts to convert this type-erased `AnyTensor` into a strongly-typed `Tensor<T>`.
///
/// - Returns: An `AnyTensor` if `self.dataType == T.dataType`,
/// otherwise `nil` when the runtime dtype doesn’t match.
/// - Returns: A `Tensor<T>` if the runtime data type matches, otherwise `nil`.
func asTensor<T: Scalar>() -> Tensor<T>? {
guard dataType == T.dataType else { return nil }
return Tensor<T>(self)
Expand Down Expand Up @@ -586,11 +608,28 @@ public final class Tensor<T: Scalar>: Equatable {
}

/// Creates a new tensor that shares the underlying data storage with the
/// given tensor. This new tensor is a view and does not own the data.
/// given tensor, with optional metadata overrides. An empty array for
/// a parameter signifies that it should be inherited or derived.
///
/// - Parameter tensor: The tensor to create a view of.
public convenience init(_ tensor: Tensor<T>) {
self.init(AnyTensor(tensor.anyTensor))
/// - Parameters:
/// - tensor: The tensor to create a view of.
/// - shape: An override for the tensor's shape.
/// - dimensionOrder: An override for the tensor's dimension order.
/// - strides: An override for the tensor's strides.
public convenience init(
_ tensor: Tensor<T>,
shape: [Int] = [],
dimensionOrder: [Int] = [],
strides: [Int] = []
) {
self.init(
AnyTensor(
tensor.anyTensor,
shape: shape,
dimensionOrder: dimensionOrder,
strides: strides
)
)
}

/// Initializes a tensor without copying the provided data.
Expand Down
31 changes: 30 additions & 1 deletion extension/apple/ExecuTorch/Exported/ExecuTorchTensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,35 @@ __attribute__((objc_subclassing_restricted))
NS_DESIGNATED_INITIALIZER
NS_SWIFT_UNAVAILABLE("");

/**
* Creates a new tensor that shares the underlying data storage with the
* given tensor, with metadata overrides. An empty array for
* a parameter signifies that it should be inherited or derived.
*
* @param otherTensor The tensor instance to create a view of.
* @param shape An override for the tensor's shape.
* @param dimensionOrder An override for the tensor's dimension order.
* @param strides An override for the tensor's strides.
* @return A new ExecuTorchTensor instance that shares data with otherTensor.
*/
- (instancetype)initWithTensor:(ExecuTorchTensor *)otherTensor
shape:(NSArray<NSNumber *> *)shape
dimensionOrder:(NSArray<NSNumber *> *)dimensionOrder
strides:(NSArray<NSNumber *> *)strides
NS_REFINED_FOR_SWIFT;

/**
* Creates a new tensor that shares the underlying data storage with the
* given tensor, with an overridden shape.
*
* @param otherTensor The tensor instance to create a view of.
* @param shape An override for the tensor's shape.
* @return A new ExecuTorchTensor instance that shares data with otherTensor.
*/
- (instancetype)initWithTensor:(ExecuTorchTensor *)otherTensor
shape:(NSArray<NSNumber *> *)shape
NS_SWIFT_UNAVAILABLE("");

/**
* Creates a new tensor that shares the underlying data storage with the
* given tensor. This new tensor is a view and does not own the data.
Expand All @@ -158,7 +187,7 @@ __attribute__((objc_subclassing_restricted))
* @return A new ExecuTorchTensor instance that shares data with otherTensor.
*/
- (instancetype)initWithTensor:(ExecuTorchTensor *)otherTensor
NS_SWIFT_NAME(init(_:));
NS_SWIFT_UNAVAILABLE("");

/**
* Creates a deep copy of the tensor.
Expand Down
25 changes: 23 additions & 2 deletions extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,16 @@ - (instancetype)initWithNativeInstance:(void *)nativeInstance {
return self;
}

- (instancetype)initWithTensor:(ExecuTorchTensor *)otherTensor {
- (instancetype)initWithTensor:(ExecuTorchTensor *)otherTensor
shape:(NSArray<NSNumber *> *)shape
dimensionOrder:(NSArray<NSNumber *> *)dimensionOrder
strides:(NSArray<NSNumber *> *)strides {
ET_CHECK(otherTensor);
auto tensor = make_tensor_ptr(
*reinterpret_cast<TensorPtr *>(otherTensor.nativeInstance)
*reinterpret_cast<TensorPtr *>(otherTensor.nativeInstance),
utils::toVector<SizesType>(shape),
utils::toVector<DimOrderType>(dimensionOrder),
utils::toVector<StridesType>(strides)
);
self = [self initWithNativeInstance:&tensor];
if (self) {
Expand All @@ -138,6 +144,21 @@ - (instancetype)initWithTensor:(ExecuTorchTensor *)otherTensor {
return self;
}

- (instancetype)initWithTensor:(ExecuTorchTensor *)otherTensor
shape:(NSArray<NSNumber *> *)shape {
return [self initWithTensor:otherTensor
shape:shape
dimensionOrder:@[]
strides:@[]];
}

- (instancetype)initWithTensor:(ExecuTorchTensor *)otherTensor {
return [self initWithTensor:otherTensor
shape:@[]
dimensionOrder:@[]
strides:@[]];
}

- (instancetype)copy {
return [self copyWithZone:nil];
}
Expand Down
32 changes: 32 additions & 0 deletions extension/apple/ExecuTorch/__tests__/TensorTest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,38 @@ class TensorTest: XCTestCase {
XCTAssertEqual(tensor2.count, tensor1.count)
}

func testInitWithTensorDerivesStridesAndSharesStorage() {
var scalars: [Int32] = [1, 2, 3, 4, 5, 6]
let tensor1 = scalars.withUnsafeMutableBytes {
Tensor<Int32>(bytesNoCopy: $0.baseAddress!, shape: [2, 3])
}
let tensor2 = Tensor(tensor1, shape: [3, 2])

XCTAssertEqual(tensor1.withUnsafeBytes { $0.baseAddress }, tensor2.withUnsafeBytes { $0.baseAddress })
XCTAssertEqual(tensor2.shape, [3, 2])
XCTAssertEqual(tensor2.strides, [2, 1])
XCTAssertEqual(tensor2.dimensionOrder, [0, 1])

scalars[0] = 99
XCTAssertEqual(tensor2.withUnsafeBytes { $0[0] }, 99)
}

func testInitWithTensorExplicitOverridesAppliesMetadata() {
var scalars: [Float] = [1, 2, 3, 4]
let tensor1 = scalars.withUnsafeMutableBytes {
Tensor<Float>(bytesNoCopy: $0.baseAddress!, shape: [2, 2])
}
let tensor2 = Tensor(tensor1, shape: [2, 2], dimensionOrder: [1, 0], strides: [1, 2])

XCTAssertEqual(tensor1.withUnsafeBytes { $0.baseAddress }, tensor2.withUnsafeBytes { $0.baseAddress })
XCTAssertEqual(tensor2.shape, [2, 2])
XCTAssertEqual(tensor2.dimensionOrder, [1, 0])
XCTAssertEqual(tensor2.strides, [1, 2])

scalars[3] = 42
XCTAssertEqual(tensor2.withUnsafeBytes { $0[3] }, 42)
}

func testCopy() {
var data: [Double] = [10.0, 20.0, 30.0, 40.0]
let tensor1 = data.withUnsafeMutableBytes {
Expand Down
Loading