diff --git a/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm b/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm index 4456c023185..2cf62f9be8b 100644 --- a/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm +++ b/extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm @@ -137,7 +137,6 @@ - (BOOL)isEqualToTensor:(nullable ExecuTorchTensor *)other { [self.shape isEqual:other.shape] && [self.dimensionOrder isEqual:other.dimensionOrder] && [self.strides isEqual:other.strides] && - self.shapeDynamism == other.shapeDynamism && (data && otherData ? std::memcmp(data, otherData, size) == 0 : data == otherData); } diff --git a/extension/apple/ExecuTorch/__tests__/TensorTest.swift b/extension/apple/ExecuTorch/__tests__/TensorTest.swift index 052b84ae5f8..12427b43b7c 100644 --- a/extension/apple/ExecuTorch/__tests__/TensorTest.swift +++ b/extension/apple/ExecuTorch/__tests__/TensorTest.swift @@ -208,6 +208,10 @@ class TensorTest: XCTestCase { XCTAssertTrue(tensor1.isEqual(tensor1)) XCTAssertFalse(tensor1.isEqual(NSString(string: "Not a tensor"))) XCTAssertFalse(tensor4.isEqual(tensor2.copy())) + let tensor5 = data.withUnsafeMutableBytes { + Tensor(bytesNoCopy: $0.baseAddress!, shape: [2, 2], dataType: .float, shapeDynamism: .static) + } + XCTAssertTrue(tensor1.isEqual(tensor5)) } func testInitScalarsUInt8() {