@@ -1297,7 +1297,7 @@ extension Tensor {
1297
1297
}
1298
1298
1299
1299
@inlinable
1300
- @differentiable ( reverse, wrt: self where Scalar: TensorFlowFloatingPoint)
1300
+ // @differentiable(reverse, wrt: self where Scalar: TensorFlowFloatingPoint)
1301
1301
internal subscript( _ indexPath: IndexPath ) -> Tensor {
1302
1302
get {
1303
1303
let device = self . device
@@ -1323,7 +1323,7 @@ extension Tensor {
1323
1323
}
1324
1324
1325
1325
@inlinable
1326
- @differentiable ( reverse, wrt: self where Scalar: TensorFlowFloatingPoint)
1326
+ // @differentiable(reverse, wrt: self where Scalar: TensorFlowFloatingPoint)
1327
1327
public subscript( _ ranges: TensorRangeExpression ... ) -> Tensor {
1328
1328
get {
1329
1329
return self [ { IndexPath ( { ranges. map { $0. tensorRange } } ( ) ) } ( ) ]
@@ -1334,27 +1334,27 @@ extension Tensor {
1334
1334
}
1335
1335
}
1336
1336
1337
- extension Tensor where Scalar : TensorFlowFloatingPoint {
1338
- @usableFromInline
1339
- @derivative ( of: subscript)
1340
- internal func _vjpSubscript(
1341
- _ indexPath: IndexPath
1342
- ) -> ( value: Tensor , pullback: ( Tensor ) -> Tensor ) {
1343
- return (
1344
- self [ indexPath] ,
1345
- { [ shape = shapeTensor] v in
1346
- _Raw. stridedSliceGrad (
1347
- shape: shape, begin: Tensor < Int32 > ( indexPath. begin, on: device) ,
1348
- end: Tensor < Int32 > ( indexPath. end, on: device) ,
1349
- strides: Tensor < Int32 > ( indexPath. strides, on: device) , dy: v,
1350
- beginMask: indexPath. beginMask,
1351
- endMask: indexPath. endMask, ellipsisMask: indexPath. ellipsisMask,
1352
- newAxisMask: indexPath. newAxisMask,
1353
- shrinkAxisMask: indexPath. squeezeAxisMask)
1354
- }
1355
- )
1356
- }
1357
- }
1337
+ // extension Tensor {
1338
+ // @usableFromInline
1339
+ // @derivative(of: subscript)
1340
+ // internal func _vjpSubscript(
1341
+ // _ indexPath: IndexPath
1342
+ // ) -> (value: Tensor, pullback: (Tensor) -> Tensor) {
1343
+ // return (
1344
+ // self[indexPath],
1345
+ // { [shape = shapeTensor] v in
1346
+ // _Raw.stridedSliceGrad(
1347
+ // shape: shape, begin: Tensor<Int32>(indexPath.begin, on: device),
1348
+ // end: Tensor<Int32>(indexPath.end, on: device),
1349
+ // strides: Tensor<Int32>(indexPath.strides, on: device), dy: v,
1350
+ // beginMask: indexPath.beginMask,
1351
+ // endMask: indexPath.endMask, ellipsisMask: indexPath.ellipsisMask,
1352
+ // newAxisMask: indexPath.newAxisMask,
1353
+ // shrinkAxisMask: indexPath.squeezeAxisMask)
1354
+ // }
1355
+ // )
1356
+ // }
1357
+ // }
1358
1358
1359
1359
extension Tensor . IndexPath {
1360
1360
@inlinable
0 commit comments