Skip to content

Commit e02c136

Browse files
Create Value with Tensor. (#9682)
Summary: #8364 Reviewed By: mergennachin Differential Revision: D71914665 Co-authored-by: Anthony Shoumikhin <[email protected]>
1 parent 1bdee96 commit e02c136

File tree

3 files changed

+47
-0
lines changed

3 files changed

+47
-0
lines changed

extension/apple/ExecuTorch/Exported/ExecuTorchValue.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,13 +47,35 @@ __attribute__((deprecated("This API is experimental.")))
4747
*/
4848
@property(nonatomic, readonly) ExecuTorchValueTag tag;
4949

50+
/**
51+
* The tensor value if the tag is ExecuTorchValueTagTensor.
52+
*
53+
* @return A Tensor instance or nil.
54+
*/
55+
@property(nullable, nonatomic, readonly) ExecuTorchTensor *tensorValue NS_SWIFT_NAME(tensor);
56+
5057
/**
5158
* Returns YES if the value is of type None.
5259
*
5360
* @return A BOOL indicating whether the value is None.
5461
*/
5562
@property(nonatomic, readonly) BOOL isNone;
5663

64+
/**
65+
* Returns YES if the value is a Tensor.
66+
*
67+
* @return A BOOL indicating whether the value is a Tensor.
68+
*/
69+
@property(nonatomic, readonly) BOOL isTensor;
70+
71+
/**
72+
* Creates an instance encapsulating a Tensor.
73+
*
74+
* @param value An ExecuTorchTensor instance.
75+
* @return A new ExecuTorchValue instance with a tag of ExecuTorchValueTagTensor.
76+
*/
77+
+ (instancetype)valueWithTensor:(ExecuTorchTensor *)value NS_SWIFT_NAME(init(_:));
78+
5779
@end
5880

5981
NS_ASSUME_NONNULL_END

extension/apple/ExecuTorch/Exported/ExecuTorchValue.mm

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
#import "ExecuTorchValue.h"
1010

11+
#import <executorch/runtime/platform/assert.h>
12+
1113
@interface ExecuTorchValue ()
1214

1315
- (instancetype)initWithTag:(ExecuTorchValueTag)tag
@@ -20,6 +22,11 @@ @implementation ExecuTorchValue {
2022
id _value;
2123
}
2224

25+
+ (instancetype)valueWithTensor:(ExecuTorchTensor *)value {
26+
ET_CHECK(value);
27+
return [[ExecuTorchValue alloc] initWithTag:ExecuTorchValueTagTensor value:value];
28+
}
29+
2330
- (instancetype)init {
2431
return [self initWithTag:ExecuTorchValueTagNone value:nil];
2532
}
@@ -37,8 +44,16 @@ - (ExecuTorchValueTag)tag {
3744
return _tag;
3845
}
3946

47+
- (nullable ExecuTorchTensor *)tensorValue {
48+
return self.isTensor ? _value : nil;
49+
}
50+
4051
- (BOOL)isNone {
4152
return _tag == ExecuTorchValueTagNone;
4253
}
4354

55+
- (BOOL)isTensor {
56+
return _tag == ExecuTorchValueTagTensor;
57+
}
58+
4459
@end

extension/apple/ExecuTorch/__tests__/ValueTest.swift

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,14 @@ class ValueTest: XCTestCase {
1515
let value = Value()
1616
XCTAssertTrue(value.isNone)
1717
}
18+
19+
func testTensor() {
20+
var data: [Float] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
21+
let tensor = data.withUnsafeMutableBytes {
22+
Tensor(bytesNoCopy: $0.baseAddress!, shape: [2, 3], dataType: .float)
23+
}
24+
let value = Value(tensor)
25+
XCTAssertTrue(value.isTensor)
26+
XCTAssertEqual(value.tensor, tensor)
27+
}
1828
}

0 commit comments

Comments
 (0)