Skip to content

Commit 7a1dc2f

Browse files
authored
Make Tensor data accessors non-throwing.
Differential Revision: D79381680 Pull Request resolved: #13087
1 parent 4266820 commit 7a1dc2f

File tree

3 files changed

+114
-121
lines changed

3 files changed

+114
-121
lines changed

docs/source/using-executorch-ios.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ let inputTensor = Tensor<Float>(&imageBuffer, shape: [1, 3, 224, 224])
246246
let outputTensor: Tensor<Float> = try module.forward(inputTensor)[0].tensor()!
247247
248248
// Copy the tensor data into logits array for easier access.
249-
let logits = try outputTensor.scalars()
249+
let logits = outputTensor.scalars()
250250
251251
// Use logits...
252252
```
@@ -444,11 +444,11 @@ Swift:
444444
let tensor = Tensor<Float>([1.0, 2.0, 3.0, 4.0], shape: [2, 2])
445445
446446
// Get data copy as a Swift array.
447-
let scalars = try tensor.scalars()
447+
let scalars = tensor.scalars()
448448
print("All scalars: \(scalars)") // [1.0, 2.0, 3.0, 4.0]
449449
450450
// Access data via a buffer pointer.
451-
try tensor.withUnsafeBytes { buffer in
451+
tensor.withUnsafeBytes { buffer in
452452
print("First float element: \(buffer.first ?? 0.0)")
453453
}
454454
@@ -482,7 +482,7 @@ Swift:
482482
let tensor = Tensor<Float>([1.0, 2.0, 3.0, 4.0], shape: [2, 2])
483483
484484
// Modify the tensor's data in place.
485-
try tensor.withUnsafeMutableBytes { buffer in
485+
tensor.withUnsafeMutableBytes { buffer in
486486
buffer[1] = 200.0
487487
}
488488
// tensor's data is now [1.0, 200.0, 3.0, 4.0]

extension/apple/ExecuTorch/Exported/ExecuTorch+Tensor.swift

Lines changed: 16 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -770,35 +770,29 @@ public final class Tensor<T: Scalar>: Equatable {
770770
/// - Parameter body: A closure that receives an `UnsafeBufferPointer<T>` bound to the tensor’s data.
771771
/// - Returns: The value returned by `body`.
772772
/// - Throws: Any error thrown by `body`.
773-
public func withUnsafeBytes<R>(_ body: (UnsafeBufferPointer<T>) throws -> R) throws -> R {
774-
var result: Result<R, Error>?
775-
anyTensor.bytes { pointer, count, _ in
776-
result = Result { try body(
777-
UnsafeBufferPointer(
778-
start: pointer.assumingMemoryBound(to: T.self),
779-
count: count
780-
)
781-
) }
773+
public func withUnsafeBytes<R>(_ body: (UnsafeBufferPointer<T>) throws -> R) rethrows -> R {
774+
try withoutActuallyEscaping(body) { body in
775+
var result: Result<R, Error>?
776+
anyTensor.bytes { pointer, count, _ in
777+
result = Result { try body(UnsafeBufferPointer(start: pointer.assumingMemoryBound(to: T.self), count: count)) }
778+
}
779+
return try result!.get()
782780
}
783-
return try result!.get()
784781
}
785782

786783
/// Calls the closure with a typed, mutable buffer pointer over the tensor’s elements.
787784
///
788785
/// - Parameter body: A closure that receives an `UnsafeMutableBufferPointer<T>` bound to the tensor’s data.
789786
/// - Returns: The value returned by `body`.
790787
/// - Throws: Any error thrown by `body`.
791-
public func withUnsafeMutableBytes<R>(_ body: (UnsafeMutableBufferPointer<T>) throws -> R) throws -> R {
792-
var result: Result<R, Error>?
793-
anyTensor.mutableBytes { pointer, count, _ in
794-
result = Result { try body(
795-
UnsafeMutableBufferPointer(
796-
start: pointer.assumingMemoryBound(to: T.self),
797-
count: count
798-
)
799-
) }
788+
public func withUnsafeMutableBytes<R>(_ body: (UnsafeMutableBufferPointer<T>) throws -> R) rethrows -> R {
789+
try withoutActuallyEscaping(body) { body in
790+
var result: Result<R, Error>?
791+
anyTensor.mutableBytes { pointer, count, _ in
792+
result = Result { try body(UnsafeMutableBufferPointer(start: pointer.assumingMemoryBound(to: T.self), count: count)) }
793+
}
794+
return try result!.get()
800795
}
801-
return try result!.get()
802796
}
803797

804798
/// Resizes the tensor to a new shape.
@@ -830,9 +824,8 @@ public extension Tensor {
830824
/// Returns the tensor's elements as an array of scalars.
831825
///
832826
/// - Returns: An array of scalars of type `T`.
833-
/// - Throws: An error if the underlying data cannot be accessed.
834-
func scalars() throws -> [T] {
835-
try withUnsafeBytes { Array($0) }
827+
func scalars() -> [T] {
828+
withUnsafeBytes { Array($0) }
836829
}
837830
}
838831

0 commit comments

Comments
 (0)