File tree Expand file tree Collapse file tree 3 files changed +57
-0
lines changed
extension/apple/ExecuTorch Expand file tree Collapse file tree 3 files changed +57
-0
lines changed Original file line number Diff line number Diff line change @@ -544,4 +544,26 @@ __attribute__((deprecated("This API is experimental.")))
544544
545545@end
546546
547+ @interface ExecuTorchTensor (Scalar)
548+
549+ /* *
550+ * Initializes a tensor with a single scalar value and a specified data type.
551+ *
552+ * @param scalar An NSNumber representing the scalar value.
553+ * @param dataType An ExecuTorchDataType value specifying the element type.
554+ * @return An initialized ExecuTorchTensor instance representing the scalar.
555+ */
556+ - (instancetype )initWithScalar : (NSNumber *)scalar
557+ dataType : (ExecuTorchDataType)dataType NS_SWIFT_NAME(init(_:dataType:));
558+
559+ /* *
560+ * Initializes a tensor with a single scalar value, automatically deducing its data type.
561+ *
562+ * @param scalar An NSNumber representing the scalar value.
563+ * @return An initialized ExecuTorchTensor instance representing the scalar.
564+ */
565+ - (instancetype )initWithScalar : (NSNumber *)scalar NS_SWIFT_NAME(init(_:));
566+
567+ @end
568+
547569NS_ASSUME_NONNULL_END
Original file line number Diff line number Diff line change @@ -455,3 +455,26 @@ - (instancetype)initWithScalars:(NSArray<NSNumber *> *)scalars {
455455}
456456
457457@end
458+
459+ @implementation ExecuTorchTensor (Scalar)
460+
461+ - (instancetype )initWithScalar : (NSNumber *)scalar
462+ dataType : (ExecuTorchDataType)dataType {
463+ return [self initWithScalars: @[scalar]
464+ shape: @[]
465+ strides: @[]
466+ dimensionOrder: @[]
467+ dataType: dataType
468+ shapeDynamism: ExecuTorchShapeDynamismDynamicBound];
469+ }
470+
471+ - (instancetype )initWithScalar : (NSNumber *)scalar {
472+ return [self initWithScalars: @[scalar]
473+ shape: @[]
474+ strides: @[]
475+ dimensionOrder: @[]
476+ dataType: static_cast <ExecuTorchDataType>(utils: :deduceType (scalar))
477+ shapeDynamism: ExecuTorchShapeDynamismDynamicBound];
478+ }
479+
480+ @end
Original file line number Diff line number Diff line change @@ -392,4 +392,16 @@ class TensorTest: XCTestCase {
392392 XCTAssertEqual ( Array ( UnsafeBufferPointer ( start: pointer. assumingMemoryBound ( to: UInt . self) , count: count) ) , data)
393393 }
394394 }
395+
396+ func testInitFloat( ) {
397+ let tensor = Tensor ( Float ( 42.0 ) as NSNumber )
398+ XCTAssertEqual ( tensor. dataType, . float)
399+ XCTAssertEqual ( tensor. shape, [ ] )
400+ XCTAssertEqual ( tensor. strides, [ ] )
401+ XCTAssertEqual ( tensor. dimensionOrder, [ ] )
402+ XCTAssertEqual ( tensor. count, 1 )
403+ tensor. bytes { pointer, count, dataType in
404+ XCTAssertEqual ( UnsafeBufferPointer ( start: pointer. assumingMemoryBound ( to: Float . self) , count: count) . first, 42.0 )
405+ }
406+ }
395407}
You can’t perform that action at this time.
0 commit comments