Skip to content

Commit fb81635

Browse files
committed
Before commenting out subscript
1 parent 754b49e commit fb81635

File tree

4 files changed

+33
-122
lines changed

4 files changed

+33
-122
lines changed

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ if(NOT X10_FOUND AND NOT USE_BUNDLED_X10)
189189
COMMAND
190190
rm -rf <SOURCE_DIR>/bazel-bin # ${CMAKE_COMMAND} -E rm -Rrf <SOURCE_DIR>/bazel-bin
191191
COMMAND
192-
bazel build ${VISIBILITY_FLAGS} -c opt --define framework_shared_object=false //tensorflow/compiler/tf2xla/xla_tensor:x10 --nocheck_visibility
192+
bazel build ${VISIBILITY_FLAGS} -c opt --define framework_shared_object=false //tensorflow:tensorflow //tensorflow/compiler/tf2xla/xla_tensor:x10 --nocheck_visibility
193193
COMMAND
194194
bazel shutdown
195195
INSTALL_COMMAND

Sources/TensorFlow/Layers/Normalization.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ import _Differentiation
2424
/// - scale: The tensor to be applied to normalized tensor.
2525
/// - varianceEpsilon: The small number to avoid dividing by 0.
2626
@differentiable(reverse, wrt: (input, mean, variance, offset, scale))
27-
private func normalize<Scalar: TensorFlowFloatingPoint>(
27+
private func normalize<Scalar: TensorFlowFloatingPoint> (
2828
_ input: Tensor<Scalar>,
2929
mean: Tensor<Scalar>,
3030
variance: Tensor<Scalar>,

Sources/TensorFlow/Layers/Sequential.swift

Lines changed: 17 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,7 @@ import _Differentiation
4747
/// ````
4848
public struct Sequential<Layer1: Module, Layer2: Layer>: Module
4949
where
50-
Layer1.Output == Layer2.Input,
51-
Layer1.TangentVector.VectorSpaceScalar == Layer2.TangentVector.VectorSpaceScalar
50+
Layer1.Output == Layer2.Input
5251
{
5352
public var layer1: Layer1
5453
public var layer2: Layer2
@@ -77,42 +76,28 @@ extension Sequential: Layer where Layer1: Layer {
7776
/// A layer that sequentially composes 3 layers.
7877
public typealias Sequential3<L1: Module, L2: Layer, L3: Layer> = Sequential<L1, Sequential<L2, L3>>
7978
where
80-
L1.Output == L2.Input, L2.Output == L3.Input,
81-
L1.TangentVector.VectorSpaceScalar == L2.TangentVector.VectorSpaceScalar,
82-
L2.TangentVector.VectorSpaceScalar == L3.TangentVector.VectorSpaceScalar
79+
L1.Output == L2.Input, L2.Output == L3.Input
8380

8481
/// A layer that sequentially composes 4 layers.
8582
public typealias Sequential4<L1: Module, L2: Layer, L3: Layer, L4: Layer> = Sequential<
8683
L1, Sequential<L2, Sequential<L3, L4>>
8784
>
8885
where
89-
L1.Output == L2.Input, L2.Output == L3.Input, L3.Output == L4.Input,
90-
L1.TangentVector.VectorSpaceScalar == L2.TangentVector.VectorSpaceScalar,
91-
L2.TangentVector.VectorSpaceScalar == L3.TangentVector.VectorSpaceScalar,
92-
L3.TangentVector.VectorSpaceScalar == L4.TangentVector.VectorSpaceScalar
86+
L1.Output == L2.Input, L2.Output == L3.Input, L3.Output == L4.Input
9387

9488
/// A layer that sequentially composes 5 layers.
9589
public typealias Sequential5<L1: Module, L2: Layer, L3: Layer, L4: Layer, L5: Layer> = Sequential<
9690
L1, Sequential<L2, Sequential<L3, Sequential<L4, L5>>>
9791
>
9892
where
99-
L1.Output == L2.Input, L2.Output == L3.Input, L3.Output == L4.Input, L4.Output == L5.Input,
100-
L1.TangentVector.VectorSpaceScalar == L2.TangentVector.VectorSpaceScalar,
101-
L2.TangentVector.VectorSpaceScalar == L3.TangentVector.VectorSpaceScalar,
102-
L3.TangentVector.VectorSpaceScalar == L4.TangentVector.VectorSpaceScalar,
103-
L4.TangentVector.VectorSpaceScalar == L5.TangentVector.VectorSpaceScalar
93+
L1.Output == L2.Input, L2.Output == L3.Input, L3.Output == L4.Input, L4.Output == L5.Input
10494

10595
/// A layer that sequentially composes 6 layers.
10696
public typealias Sequential6<L1: Module, L2: Layer, L3: Layer, L4: Layer, L5: Layer, L6: Layer> =
10797
Sequential<L1, Sequential<L2, Sequential<L3, Sequential<L4, Sequential<L5, L6>>>>>
10898
where
10999
L1.Output == L2.Input, L2.Output == L3.Input, L3.Output == L4.Input, L4.Output == L5.Input,
110-
L5.Output == L6.Input,
111-
L1.TangentVector.VectorSpaceScalar == L2.TangentVector.VectorSpaceScalar,
112-
L2.TangentVector.VectorSpaceScalar == L3.TangentVector.VectorSpaceScalar,
113-
L3.TangentVector.VectorSpaceScalar == L4.TangentVector.VectorSpaceScalar,
114-
L4.TangentVector.VectorSpaceScalar == L5.TangentVector.VectorSpaceScalar,
115-
L5.TangentVector.VectorSpaceScalar == L6.TangentVector.VectorSpaceScalar
100+
L5.Output == L6.Input
116101

117102
/// A layer that sequentially composes 7 layers.
118103
public typealias Sequential7<
@@ -122,13 +107,7 @@ public typealias Sequential7<
122107
>
123108
where
124109
L1.Output == L2.Input, L2.Output == L3.Input, L3.Output == L4.Input, L4.Output == L5.Input,
125-
L5.Output == L6.Input, L6.Output == L7.Input,
126-
L1.TangentVector.VectorSpaceScalar == L2.TangentVector.VectorSpaceScalar,
127-
L2.TangentVector.VectorSpaceScalar == L3.TangentVector.VectorSpaceScalar,
128-
L3.TangentVector.VectorSpaceScalar == L4.TangentVector.VectorSpaceScalar,
129-
L4.TangentVector.VectorSpaceScalar == L5.TangentVector.VectorSpaceScalar,
130-
L5.TangentVector.VectorSpaceScalar == L6.TangentVector.VectorSpaceScalar,
131-
L6.TangentVector.VectorSpaceScalar == L7.TangentVector.VectorSpaceScalar
110+
L5.Output == L6.Input, L6.Output == L7.Input
132111

133112
/// A layer that sequentially composes 8 layers.
134113
public typealias Sequential8<
@@ -139,14 +118,7 @@ public typealias Sequential8<
139118
>
140119
where
141120
L1.Output == L2.Input, L2.Output == L3.Input, L3.Output == L4.Input, L4.Output == L5.Input,
142-
L5.Output == L6.Input, L6.Output == L7.Input, L7.Output == L8.Input,
143-
L1.TangentVector.VectorSpaceScalar == L2.TangentVector.VectorSpaceScalar,
144-
L2.TangentVector.VectorSpaceScalar == L3.TangentVector.VectorSpaceScalar,
145-
L3.TangentVector.VectorSpaceScalar == L4.TangentVector.VectorSpaceScalar,
146-
L4.TangentVector.VectorSpaceScalar == L5.TangentVector.VectorSpaceScalar,
147-
L5.TangentVector.VectorSpaceScalar == L6.TangentVector.VectorSpaceScalar,
148-
L6.TangentVector.VectorSpaceScalar == L7.TangentVector.VectorSpaceScalar,
149-
L7.TangentVector.VectorSpaceScalar == L8.TangentVector.VectorSpaceScalar
121+
L5.Output == L6.Input, L6.Output == L7.Input, L7.Output == L8.Input
150122

151123
/// A layer that sequentially composes 9 layers.
152124
public typealias Sequential9<
@@ -162,15 +134,7 @@ public typealias Sequential9<
162134
>
163135
where
164136
L1.Output == L2.Input, L2.Output == L3.Input, L3.Output == L4.Input, L4.Output == L5.Input,
165-
L5.Output == L6.Input, L6.Output == L7.Input, L7.Output == L8.Input, L8.Output == L9.Input,
166-
L1.TangentVector.VectorSpaceScalar == L2.TangentVector.VectorSpaceScalar,
167-
L2.TangentVector.VectorSpaceScalar == L3.TangentVector.VectorSpaceScalar,
168-
L3.TangentVector.VectorSpaceScalar == L4.TangentVector.VectorSpaceScalar,
169-
L4.TangentVector.VectorSpaceScalar == L5.TangentVector.VectorSpaceScalar,
170-
L5.TangentVector.VectorSpaceScalar == L6.TangentVector.VectorSpaceScalar,
171-
L6.TangentVector.VectorSpaceScalar == L7.TangentVector.VectorSpaceScalar,
172-
L7.TangentVector.VectorSpaceScalar == L8.TangentVector.VectorSpaceScalar,
173-
L8.TangentVector.VectorSpaceScalar == L9.TangentVector.VectorSpaceScalar
137+
L5.Output == L6.Input, L6.Output == L7.Input, L7.Output == L8.Input, L8.Output == L9.Input
174138

175139
/// A layer that sequentially composes 10 layers.
176140
public typealias Sequential10<
@@ -191,16 +155,7 @@ public typealias Sequential10<
191155
where
192156
L1.Output == L2.Input, L2.Output == L3.Input, L3.Output == L4.Input, L4.Output == L5.Input,
193157
L5.Output == L6.Input, L6.Output == L7.Input, L7.Output == L8.Input, L8.Output == L9.Input,
194-
L9.Output == L10.Input,
195-
L1.TangentVector.VectorSpaceScalar == L2.TangentVector.VectorSpaceScalar,
196-
L2.TangentVector.VectorSpaceScalar == L3.TangentVector.VectorSpaceScalar,
197-
L3.TangentVector.VectorSpaceScalar == L4.TangentVector.VectorSpaceScalar,
198-
L4.TangentVector.VectorSpaceScalar == L5.TangentVector.VectorSpaceScalar,
199-
L5.TangentVector.VectorSpaceScalar == L6.TangentVector.VectorSpaceScalar,
200-
L6.TangentVector.VectorSpaceScalar == L7.TangentVector.VectorSpaceScalar,
201-
L7.TangentVector.VectorSpaceScalar == L8.TangentVector.VectorSpaceScalar,
202-
L8.TangentVector.VectorSpaceScalar == L9.TangentVector.VectorSpaceScalar,
203-
L9.TangentVector.VectorSpaceScalar == L10.TangentVector.VectorSpaceScalar
158+
L9.Output == L10.Input
204159

205160
@resultBuilder
206161
public struct LayerBuilder {
@@ -217,9 +172,7 @@ public struct LayerBuilder {
217172
-> Sequential<L1, Sequential<L2, L3>>
218173
where
219174
L1.Output == L2.Input,
220-
L2.Output == L3.Input,
221-
L1.TangentVector.VectorSpaceScalar == L2.TangentVector.VectorSpaceScalar,
222-
L2.TangentVector.VectorSpaceScalar == L3.TangentVector.VectorSpaceScalar
175+
L2.Output == L3.Input
223176
{
224177
Sequential(l1, Sequential(l2, l3))
225178
}
@@ -234,10 +187,7 @@ public struct LayerBuilder {
234187
where
235188
L1.Output == L2.Input,
236189
L2.Output == L3.Input,
237-
L3.Output == L4.Input,
238-
L1.TangentVector.VectorSpaceScalar == L2.TangentVector.VectorSpaceScalar,
239-
L2.TangentVector.VectorSpaceScalar == L3.TangentVector.VectorSpaceScalar,
240-
L3.TangentVector.VectorSpaceScalar == L4.TangentVector.VectorSpaceScalar
190+
L3.Output == L4.Input
241191
{
242192
Sequential(l1, Sequential(l2, Sequential(l3, l4)))
243193
}
@@ -254,11 +204,7 @@ public struct LayerBuilder {
254204
L1.Output == L2.Input,
255205
L2.Output == L3.Input,
256206
L3.Output == L4.Input,
257-
L4.Output == L5.Input,
258-
L1.TangentVector.VectorSpaceScalar == L2.TangentVector.VectorSpaceScalar,
259-
L2.TangentVector.VectorSpaceScalar == L3.TangentVector.VectorSpaceScalar,
260-
L3.TangentVector.VectorSpaceScalar == L4.TangentVector.VectorSpaceScalar,
261-
L4.TangentVector.VectorSpaceScalar == L5.TangentVector.VectorSpaceScalar
207+
L4.Output == L5.Input
262208
{
263209
Sequential(l1, Sequential(l2, Sequential(l3, Sequential(l4, l5))))
264210
}
@@ -277,12 +223,7 @@ public struct LayerBuilder {
277223
L2.Output == L3.Input,
278224
L3.Output == L4.Input,
279225
L4.Output == L5.Input,
280-
L5.Output == L6.Input,
281-
L1.TangentVector.VectorSpaceScalar == L2.TangentVector.VectorSpaceScalar,
282-
L2.TangentVector.VectorSpaceScalar == L3.TangentVector.VectorSpaceScalar,
283-
L3.TangentVector.VectorSpaceScalar == L4.TangentVector.VectorSpaceScalar,
284-
L4.TangentVector.VectorSpaceScalar == L5.TangentVector.VectorSpaceScalar,
285-
L5.TangentVector.VectorSpaceScalar == L6.TangentVector.VectorSpaceScalar
226+
L5.Output == L6.Input
286227
{
287228
Sequential(l1, Sequential(l2, Sequential(l3, Sequential(l4, Sequential(l5, l6)))))
288229
}
@@ -305,13 +246,7 @@ public struct LayerBuilder {
305246
L3.Output == L4.Input,
306247
L4.Output == L5.Input,
307248
L5.Output == L6.Input,
308-
L6.Output == L7.Input,
309-
L1.TangentVector.VectorSpaceScalar == L2.TangentVector.VectorSpaceScalar,
310-
L2.TangentVector.VectorSpaceScalar == L3.TangentVector.VectorSpaceScalar,
311-
L3.TangentVector.VectorSpaceScalar == L4.TangentVector.VectorSpaceScalar,
312-
L4.TangentVector.VectorSpaceScalar == L5.TangentVector.VectorSpaceScalar,
313-
L5.TangentVector.VectorSpaceScalar == L6.TangentVector.VectorSpaceScalar,
314-
L6.TangentVector.VectorSpaceScalar == L7.TangentVector.VectorSpaceScalar
249+
L6.Output == L7.Input
315250
{
316251
Sequential(
317252
l1, Sequential(l2, Sequential(l3, Sequential(l4, Sequential(l5, Sequential(l6, l7))))))
@@ -340,14 +275,7 @@ public struct LayerBuilder {
340275
L4.Output == L5.Input,
341276
L5.Output == L6.Input,
342277
L6.Output == L7.Input,
343-
L7.Output == L8.Input,
344-
L1.TangentVector.VectorSpaceScalar == L2.TangentVector.VectorSpaceScalar,
345-
L2.TangentVector.VectorSpaceScalar == L3.TangentVector.VectorSpaceScalar,
346-
L3.TangentVector.VectorSpaceScalar == L4.TangentVector.VectorSpaceScalar,
347-
L4.TangentVector.VectorSpaceScalar == L5.TangentVector.VectorSpaceScalar,
348-
L5.TangentVector.VectorSpaceScalar == L6.TangentVector.VectorSpaceScalar,
349-
L6.TangentVector.VectorSpaceScalar == L7.TangentVector.VectorSpaceScalar,
350-
L7.TangentVector.VectorSpaceScalar == L8.TangentVector.VectorSpaceScalar
278+
L7.Output == L8.Input
351279
{
352280
Sequential(
353281
l1,
@@ -383,15 +311,7 @@ public struct LayerBuilder {
383311
L5.Output == L6.Input,
384312
L6.Output == L7.Input,
385313
L7.Output == L8.Input,
386-
L8.Output == L9.Input,
387-
L1.TangentVector.VectorSpaceScalar == L2.TangentVector.VectorSpaceScalar,
388-
L2.TangentVector.VectorSpaceScalar == L3.TangentVector.VectorSpaceScalar,
389-
L3.TangentVector.VectorSpaceScalar == L4.TangentVector.VectorSpaceScalar,
390-
L4.TangentVector.VectorSpaceScalar == L5.TangentVector.VectorSpaceScalar,
391-
L5.TangentVector.VectorSpaceScalar == L6.TangentVector.VectorSpaceScalar,
392-
L6.TangentVector.VectorSpaceScalar == L7.TangentVector.VectorSpaceScalar,
393-
L7.TangentVector.VectorSpaceScalar == L8.TangentVector.VectorSpaceScalar,
394-
L8.TangentVector.VectorSpaceScalar == L9.TangentVector.VectorSpaceScalar
314+
L8.Output == L9.Input
395315
{
396316
Sequential(
397317
l1,
@@ -437,16 +357,7 @@ public struct LayerBuilder {
437357
L6.Output == L7.Input,
438358
L7.Output == L8.Input,
439359
L8.Output == L9.Input,
440-
L9.Output == L10.Input,
441-
L1.TangentVector.VectorSpaceScalar == L2.TangentVector.VectorSpaceScalar,
442-
L2.TangentVector.VectorSpaceScalar == L3.TangentVector.VectorSpaceScalar,
443-
L3.TangentVector.VectorSpaceScalar == L4.TangentVector.VectorSpaceScalar,
444-
L4.TangentVector.VectorSpaceScalar == L5.TangentVector.VectorSpaceScalar,
445-
L5.TangentVector.VectorSpaceScalar == L6.TangentVector.VectorSpaceScalar,
446-
L6.TangentVector.VectorSpaceScalar == L7.TangentVector.VectorSpaceScalar,
447-
L7.TangentVector.VectorSpaceScalar == L8.TangentVector.VectorSpaceScalar,
448-
L8.TangentVector.VectorSpaceScalar == L9.TangentVector.VectorSpaceScalar,
449-
L9.TangentVector.VectorSpaceScalar == L10.TangentVector.VectorSpaceScalar
360+
L9.Output == L10.Input
450361
{
451362
Sequential(
452363
l1,

Tests/TensorFlowTests/TensorAutoDiffTests.swift

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -207,15 +207,15 @@ final class TensorAutoDiffTests: XCTestCase {
207207
XCTAssertTrue(
208208
(Tensor<Float>(1), Tensor<Float>(1))
209209
== gradient(at: Tensor<Float>(0), Tensor<Float>(0), in: f))
210-
XCTAssertTrue(([1], [1]) == pullback(at: [1], [10], in: f)([1]))
210+
XCTAssertTrue(([1], [1]) == pullback(at: [1], [10], of: f)([1]))
211211
}
212212

213213
func testSubtract() {
214214
func f(a: Tensor<Float>, b: Tensor<Float>) -> Tensor<Float> { a - b }
215215
XCTAssertTrue(
216216
(Tensor<Float>(1), Tensor<Float>(-1))
217217
== gradient(at: Tensor<Float>(0), Tensor<Float>(0), in: f))
218-
XCTAssertTrue(([1], [-1]) == pullback(at: [1], [10], in: f)([1]))
218+
XCTAssertTrue(([1], [-1]) == pullback(at: [1], [10], of: f)([1]))
219219
}
220220

221221
func testMultiply() {
@@ -226,21 +226,21 @@ final class TensorAutoDiffTests: XCTestCase {
226226

227227
func testDivide() {
228228
func f(a: Tensor<Float>, b: Tensor<Float>) -> Tensor<Float> { a / b }
229-
XCTAssertTrue(([0.1], [-0.01]) == pullback(at: [1], [10], in: f)([1]))
229+
XCTAssertTrue(([0.1], [-0.01]) == pullback(at: [1], [10], of: f)([1]))
230230
}
231231

232232
func testMatmul() {
233233
func f(a: Tensor<Float>, b: Tensor<Float>) -> Tensor<Float> { matmul(a, b) }
234234
let v = Tensor<Float>(ones: [1, 1])
235-
XCTAssertTrue(([[0]], [[0]]) == pullback(at: [[0]], [[0]], in: f)(v))
236-
XCTAssertTrue(([[10]], [[1]]) == pullback(at: [[1]], [[10]], in: f)(v))
235+
XCTAssertTrue(([[0]], [[0]]) == pullback(at: [[0]], [[0]], of: f)(v))
236+
XCTAssertTrue(([[10]], [[1]]) == pullback(at: [[1]], [[10]], of: f)(v))
237237
}
238238

239239
func testDot() {
240240
func f(a: Tensor<Float>, b: Tensor<Float>) -> Tensor<Float> { a b }
241241
let v = Tensor<Float>(ones: [1, 1])
242-
XCTAssertTrue(([[0]], [[0]]) == pullback(at: [[0]], [[0]], in: f)(v))
243-
XCTAssertTrue(([[10]], [[1]]) == pullback(at: [[1]], [[10]], in: f)(v))
242+
XCTAssertTrue(([[0]], [[0]]) == pullback(at: [[0]], [[0]], of: f)(v))
243+
XCTAssertTrue(([[10]], [[1]]) == pullback(at: [[1]], [[10]], of: f)(v))
244244
}
245245

246246
func testNegate() {
@@ -509,15 +509,15 @@ final class TensorAutoDiffTests: XCTestCase {
509509
func testExpandingShape() {
510510
func f1(a: Tensor<Float>) -> Tensor<Float> { a.expandingShape(at: 0).squared() }
511511
func f2(a: Tensor<Float>) -> Tensor<Float> { a.squared().expandingShape(at: 0) }
512-
XCTAssertEqual(pullback(at: [3, 5], in: f1)([[1, 1]]), [6, 10])
513-
XCTAssertEqual(pullback(at: [3, 5], in: f2)([[1, 1]]), [6, 10])
512+
XCTAssertEqual(pullback(at: [3, 5], of: f1)([[1, 1]]), [6, 10])
513+
XCTAssertEqual(pullback(at: [3, 5], of: f2)([[1, 1]]), [6, 10])
514514
}
515515

516516
func testSqueezingShape() {
517517
func f1(a: Tensor<Float>) -> Tensor<Float> { a.squeezingShape(at: 0).squared() }
518518
func f2(a: Tensor<Float>) -> Tensor<Float> { a.squared().squeezingShape(at: 0) }
519-
XCTAssertEqual(pullback(at: [[3, 5]], in: f1)([1, 1]), [[6, 10]])
520-
XCTAssertEqual(pullback(at: [[3, 5]], in: f2)([1, 1]), [[6, 10]])
519+
XCTAssertEqual(pullback(at: [[3, 5]], of: f1)([1, 1]), [[6, 10]])
520+
XCTAssertEqual(pullback(at: [[3, 5]], of: f2)([1, 1]), [[6, 10]])
521521
}
522522

523523
func testTiled() {
@@ -536,8 +536,8 @@ final class TensorAutoDiffTests: XCTestCase {
536536
func f2(a: Tensor<Float>) -> Tensor<Float> {
537537
a.squared().reshaped(toShape: Tensor<Int32>([2, 1]))
538538
}
539-
XCTAssertEqual(pullback(at: [[3, 5]], in: f1)([[1], [1]]), [[6, 10]])
540-
XCTAssertEqual(pullback(at: [[3, 5]], in: f2)([[1], [1]]), [[6, 10]])
539+
XCTAssertEqual(pullback(at: [[3, 5]], of: f1)([[1], [1]]), [[6, 10]])
540+
XCTAssertEqual(pullback(at: [[3, 5]], of: f2)([[1], [1]]), [[6, 10]])
541541
}
542542

543543
func testReshaped() {
@@ -651,7 +651,7 @@ final class TensorAutoDiffTests: XCTestCase {
651651
a = a + x
652652
return a + x
653653
}
654-
XCTAssertEqual(Tensor([4, 4]), pullback(at: Tensor([4, 5]), in: foo)([1, 1]))
654+
XCTAssertEqual(Tensor([4, 4]), pullback(at: Tensor([4, 5]), of: foo)([1, 1]))
655655

656656
func bar(x: Tensor<Float>) -> Tensor<Float> {
657657
var a = x

0 commit comments

Comments
 (0)