Skip to content

Commit 1fdd19a

Browse files
shoumikhinfacebook-github-bot
authored andcommitted
Init Tensor with an array of scalar by sharing the memory.
Summary: . Differential Revision: D77274533
1 parent c70047f commit 1fdd19a

File tree

2 files changed

+99
-24
lines changed

2 files changed

+99
-24
lines changed

extension/apple/ExecuTorch/Exported/ExecuTorch+Tensor.swift

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -693,6 +693,35 @@ public class Tensor<T: Scalar>: Equatable {
693693
))
694694
}
695695

696+
/// Initializes a tensor without copying the data from an existing array.
697+
///
698+
/// - Parameters:
699+
/// - scalars: An `inout` array of scalar values to share memory with.
700+
/// - shape: An array of integers representing the desired tensor shape. If empty, the shape is inferred as `[scalars.count]`.
701+
/// - strides: An array of integers representing the tensor strides.
702+
/// - dimensionOrder: An array of integers indicating the order of dimensions.
703+
/// - shapeDynamism: A `ShapeDynamism` value indicating the shape dynamism.
704+
public convenience init(
705+
_ scalars: inout [T],
706+
shape: [Int] = [],
707+
strides: [Int] = [],
708+
dimensionOrder: [Int] = [],
709+
shapeDynamism: ShapeDynamism = .dynamicBound
710+
) {
711+
let newShape = shape.isEmpty ? [scalars.count] : shape
712+
precondition(scalars.count == elementCount(ofShape: newShape))
713+
self.init(scalars.withUnsafeMutableBufferPointer {
714+
AnyTensor(
715+
bytesNoCopy: $0.baseAddress!,
716+
shape: newShape,
717+
strides: strides,
718+
dimensionOrder: dimensionOrder,
719+
dataType: T.dataType,
720+
shapeDynamism: shapeDynamism
721+
)
722+
})
723+
}
724+
696725
/// Initializes a tensor with an array of scalar values.
697726
///
698727
/// - Parameters:

extension/apple/ExecuTorch/__tests__/TensorTest.swift

Lines changed: 70 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,8 @@ class TensorTest: XCTestCase {
5656

5757
func testInitBytesNoCopy() {
5858
var data: [Float] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
59-
let tensor: Tensor<Float> = data.withUnsafeMutableBytes {
60-
Tensor(bytesNoCopy: $0.baseAddress!, shape: [2, 3])
59+
let tensor = data.withUnsafeMutableBytes {
60+
Tensor<Float>(bytesNoCopy: $0.baseAddress!, shape: [2, 3])
6161
}
6262
// Modify the original data to make sure the tensor does not copy the data.
6363
data.indices.forEach { data[$0] += 1 }
@@ -73,8 +73,8 @@ class TensorTest: XCTestCase {
7373

7474
func testInitBytes() {
7575
var data: [Double] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
76-
let tensor: Tensor<Double> = data.withUnsafeMutableBytes {
77-
Tensor(bytes: $0.baseAddress!, shape: [2, 3])
76+
let tensor = data.withUnsafeMutableBytes {
77+
Tensor<Double>(bytes: $0.baseAddress!, shape: [2, 3])
7878
}
7979
// Modify the original data to make sure the tensor copies the data.
8080
data.indices.forEach { data[$0] += 1 }
@@ -91,14 +91,14 @@ class TensorTest: XCTestCase {
9191
func testInitData() {
9292
let dataArray: [Float] = [1.0, 2.0, 3.0, 4.0]
9393
let data = Data(bytes: dataArray, count: dataArray.count * MemoryLayout<Float>.size)
94-
let tensor: Tensor<Float> = Tensor(data: data, shape: [4])
94+
let tensor = Tensor<Float>(data: data, shape: [4])
9595
XCTAssertEqual(tensor.count, 4)
9696
XCTAssertEqual(try tensor.scalars(), dataArray)
9797
}
9898

9999
func testWithCustomStridesAndDimensionOrder() {
100100
let data: [Float] = [1.0, 2.0, 3.0, 4.0]
101-
let tensor: Tensor<Float> = Tensor(
101+
let tensor = Tensor<Float>(
102102
bytes: data.withUnsafeBytes { $0.baseAddress! },
103103
shape: [2, 2],
104104
strides: [1, 2],
@@ -113,8 +113,8 @@ class TensorTest: XCTestCase {
113113

114114
func testMutableBytes() {
115115
var data: [Int32] = [1, 2, 3, 4]
116-
let tensor: Tensor<Int32> = data.withUnsafeMutableBytes {
117-
Tensor(bytes: $0.baseAddress!, shape: [4])
116+
let tensor = data.withUnsafeMutableBytes {
117+
Tensor<Int32>(bytes: $0.baseAddress!, shape: [4])
118118
}
119119
XCTAssertNoThrow(try tensor.withUnsafeMutableBytes { buffer in
120120
for i in buffer.indices {
@@ -126,8 +126,8 @@ class TensorTest: XCTestCase {
126126

127127
func testInitWithTensor() throws {
128128
var data: [Int] = [10, 20, 30, 40]
129-
let tensor1: Tensor<Int> = data.withUnsafeMutableBytes {
130-
Tensor(bytesNoCopy: $0.baseAddress!, shape: [2, 2])
129+
let tensor1 = data.withUnsafeMutableBytes {
130+
Tensor<Int>(bytesNoCopy: $0.baseAddress!, shape: [2, 2])
131131
}
132132
let tensor2 = Tensor(tensor1)
133133

@@ -157,8 +157,8 @@ class TensorTest: XCTestCase {
157157

158158
func testCopy() {
159159
var data: [Double] = [10.0, 20.0, 30.0, 40.0]
160-
let tensor1: Tensor<Double> = data.withUnsafeMutableBytes {
161-
Tensor(bytesNoCopy: $0.baseAddress!, shape: [2, 2])
160+
let tensor1 = data.withUnsafeMutableBytes {
161+
Tensor<Double>(bytesNoCopy: $0.baseAddress!, shape: [2, 2])
162162
}
163163
let tensor2 = tensor1.copy()
164164

@@ -171,8 +171,8 @@ class TensorTest: XCTestCase {
171171

172172
func testResize() {
173173
var data: [Int] = [1, 2, 3, 4]
174-
let tensor: Tensor<Int> = data.withUnsafeMutableBytes {
175-
Tensor(bytesNoCopy: $0.baseAddress!, shape: [4, 1])
174+
let tensor = data.withUnsafeMutableBytes {
175+
Tensor<Int>(bytesNoCopy: $0.baseAddress!, shape: [4, 1])
176176
}
177177
XCTAssertNoThrow(try tensor.resize(to: [2, 2]))
178178
XCTAssertEqual(tensor.dataType, .long)
@@ -185,38 +185,84 @@ class TensorTest: XCTestCase {
185185

186186
func testResizeError() {
187187
var data: [Int] = [1, 2, 3, 4]
188-
let tensor: Tensor<Int> = data.withUnsafeMutableBytes {
189-
Tensor(bytesNoCopy: $0.baseAddress!, shape: [4, 1])
188+
let tensor = data.withUnsafeMutableBytes {
189+
Tensor<Int>(bytesNoCopy: $0.baseAddress!, shape: [4, 1])
190190
}
191191
XCTAssertThrowsError(try tensor.resize(to: [2, 3]))
192192
}
193193

194194
func testIsEqual() {
195195
var data: [Float] = [1.0, 2.0, 3.0, 4.0]
196-
let tensor1: Tensor<Float> = data.withUnsafeMutableBytes {
197-
Tensor(bytesNoCopy: $0.baseAddress!, shape: [2, 2])
196+
let tensor1 = data.withUnsafeMutableBytes {
197+
Tensor<Float>(bytesNoCopy: $0.baseAddress!, shape: [2, 2])
198198
}
199199
let tensor2 = Tensor(tensor1)
200200
XCTAssertEqual(tensor1, tensor2)
201201
XCTAssertEqual(tensor2, tensor1)
202202

203203
var dataModified: [Float] = [1.0, 2.0, 3.0, 5.0]
204-
let tensor3: Tensor<Float> = dataModified.withUnsafeMutableBytes {
205-
Tensor(bytesNoCopy: $0.baseAddress!, shape: [2, 2])
204+
let tensor3 = dataModified.withUnsafeMutableBytes {
205+
Tensor<Float>(bytesNoCopy: $0.baseAddress!, shape: [2, 2])
206206
}
207207
XCTAssertNotEqual(tensor1, tensor3)
208-
let tensor4: Tensor<Float> = data.withUnsafeMutableBytes {
209-
Tensor(bytesNoCopy: $0.baseAddress!, shape: [4, 1])
208+
let tensor4 = data.withUnsafeMutableBytes {
209+
Tensor<Float>(bytesNoCopy: $0.baseAddress!, shape: [4, 1])
210210
}
211211
XCTAssertNotEqual(tensor1, tensor4)
212212
XCTAssertEqual(tensor1, tensor1)
213213
XCTAssertNotEqual(tensor4, tensor2)
214-
let tensor5: Tensor<Float> = data.withUnsafeMutableBytes {
215-
Tensor(bytesNoCopy: $0.baseAddress!, shape: [2, 2], shapeDynamism: .static)
214+
let tensor5 = data.withUnsafeMutableBytes {
215+
Tensor<Float>(bytesNoCopy: $0.baseAddress!, shape: [2, 2], shapeDynamism: .static)
216216
}
217217
XCTAssertEqual(tensor1, tensor5)
218218
}
219219

220+
func testInitScalarsNoCopyDefaultShape() throws {
221+
var data: [Float] = [1.0, 2.0, 3.0, 4.0]
222+
let tensor = Tensor(&data)
223+
224+
XCTAssertEqual(tensor.dataType, .float)
225+
XCTAssertEqual(tensor.shape, [4])
226+
XCTAssertEqual(tensor.strides, [1])
227+
XCTAssertEqual(tensor.dimensionOrder, [0])
228+
XCTAssertEqual(tensor.shapeDynamism, .dynamicBound)
229+
XCTAssertEqual(tensor.count, 4)
230+
data[2] = 42.0
231+
XCTAssertEqual(try tensor.scalars(), data)
232+
}
233+
234+
func testInitScalarsNoCopyWithExplicitParams() throws {
235+
var data: [Int] = [10, 20, 30, 40]
236+
let tensor = Tensor(
237+
&data,
238+
shape: [2, 2],
239+
strides: [2, 1],
240+
dimensionOrder: [1, 0],
241+
shapeDynamism: .static
242+
)
243+
XCTAssertEqual(tensor.dataType, .long)
244+
XCTAssertEqual(tensor.shape, [2, 2])
245+
XCTAssertEqual(tensor.strides, [2, 1])
246+
XCTAssertEqual(tensor.dimensionOrder, [1, 0])
247+
XCTAssertEqual(tensor.shapeDynamism, .static)
248+
XCTAssertEqual(tensor.count, 4)
249+
data = data.map { -$0 }
250+
XCTAssertEqual(try tensor.scalars(), data)
251+
}
252+
253+
func testInitScalarsBoolNoCopy() throws {
254+
var data: [Bool] = [true, false, true]
255+
let tensor = Tensor(&data)
256+
257+
XCTAssertEqual(tensor.dataType, .bool)
258+
XCTAssertEqual(tensor.shape, [3])
259+
XCTAssertEqual(tensor.strides, [1])
260+
XCTAssertEqual(tensor.dimensionOrder, [0])
261+
XCTAssertEqual(tensor.count, 3)
262+
data[1].toggle()
263+
XCTAssertEqual(try tensor.scalars(), data)
264+
}
265+
220266
func testInitScalarsUInt8() {
221267
let data: [UInt8] = [1, 2, 3, 4, 5, 6]
222268
let tensor = Tensor(data)

0 commit comments

Comments
 (0)