Skip to content

Commit 2740780

Browse files
committed
more tpu fixes
1 parent 7ca1959 commit 2740780

File tree

3 files changed

+16
-8
lines changed

3 files changed

+16
-8
lines changed

test/basic.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -988,13 +988,15 @@ end
988988
@test Array(x) Array(y) ./ 2
989989
end
990990

991+
if !contains(string(Reactant.devices()[1]), "TPU")
991992
@testset "Hlo Cost Analysis" begin
992993
x_ra = Reactant.to_rarray(rand(4, 4))
993994
mul_comp = @compile x_ra * x_ra
994995
cost = Reactant.XLA.cost_analysis(mul_comp)
995996

996997
@test cost isa Reactant.XLA.HloCostAnalysisProperties
997998
end
999+
end
9981000

9991001
function fractional_idx(times, t)
10001002
n₂ = searchsortedfirst(times, t)

test/constructor.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -105,23 +105,23 @@ end
105105
rg = Reactant.to_rarray(g)
106106

107107
@jit update!(rg)
108-
@test convert(Array, rg.data) == [2.7, 1.59]
108+
@test convert(Array, rg.data) [2.7, 1.59]
109109

110110
rg = Reactant.to_rarray(g)
111111
res = @jit selfreturn(rg)
112-
@test convert(Array, res.data) == [3.14, 1.59]
113-
@test res.radius == 2.7
112+
@test convert(Array, res.data) [3.14, 1.59]
113+
@test res.radius 2.7
114114
@test typeof(res.radius) <: ConcreteRNumber
115115

116116
rg = Reactant.to_rarray(g)
117117

118118
@jit call_update!(rg)
119-
@test convert(Array, rg.data) == [2.7, 1.59]
119+
@test convert(Array, rg.data) [2.7, 1.59]
120120

121121
rg = Reactant.to_rarray(g)
122122
res = @jit call_selfreturn(rg)
123-
@test convert(Array, res.data) == [3.14, 1.59]
124-
@test res.radius == 2.7
123+
@test convert(Array, res.data) [3.14, 1.59]
124+
@test res.radius 2.7
125125
@test typeof(res.radius) <: ConcreteRNumber
126126
end
127127

@@ -131,7 +131,7 @@ end
131131

132132
rg = Reactant.to_rarray(g)
133133
res = @jit selfreturn(rg)
134-
@test convert(Array, res[1][].data) == [3.14, 1.59]
134+
@test convert(Array, res[1][].data) [3.14, 1.59]
135135
@test convert(Array, res[2][].data) [3.14, 1.59]
136136
@test res[1][].data == res[2][].data
137137
end

test/ops.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ using SpecialFunctions: SpecialFunctions
1010
x = Reactant.to_rarray([1.0, -1.0])
1111
@test [1.0, 1.0] @jit Ops.abs(x)
1212

13+
if !contains(string(Reactant.devices()[1]), "TPU")
1314
x = Reactant.to_rarray([
1415
3.0+4im -3.0+4im
1516
3.0-4im -3.0-4im
@@ -18,6 +19,7 @@ using SpecialFunctions: SpecialFunctions
1819
5.0 5.0
1920
5.0 5.0
2021
] @jit Ops.abs(x)
22+
end
2123
end
2224

2325
@testset "add" begin
@@ -95,6 +97,7 @@ end
9597
@test cholesky(Array(x)).U @jit g1(x)
9698
@test transpose(cholesky(Array(x)).U) @jit g2(x)
9799

100+
if !contains(string(Reactant.devices()[1]), "TPU")
98101
x = Reactant.to_rarray(
99102
[
100103
10.0+0.0im 2.0-3.0im 3.0-4.0im
@@ -105,6 +108,7 @@ end
105108

106109
@test cholesky(Array(x)).U @jit g1(x)
107110
@test adjoint(cholesky(Array(x)).U) @jit g2(x)
111+
end
108112
end
109113

110114
@testset "clamp" begin
@@ -135,11 +139,12 @@ end
135139
),
136140
]
137141
x = Reactant.to_rarray([1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.0])
138-
@test [3.0, 3.0, 3.3, 4.4, 5.5, 6.6, 7.0, 7.0, 7.0, 7.0] ==
142+
@test [3.0, 3.0, 3.3, 4.4, 5.5, 6.6, 7.0, 7.0, 7.0, 7.0]
139143
@jit Ops.clamp(_min, x, _max)
140144
end
141145
end
142146

147+
if !contains(string(Reactant.devices()[1]), "TPU")
143148
@testset "complex" begin
144149
x = Reactant.to_rarray(1.1; track_numbers=true)
145150
y = Reactant.to_rarray(2.2; track_numbers=true)
@@ -149,6 +154,7 @@ end
149154
y = Reactant.to_rarray([5.5, 6.6, -7.7, -8.8])
150155
@test [1.1 + 5.5im, 2.2 + 6.6im, 3.3 - 7.7im, 4.4 - 8.8im] @jit Ops.complex(x, y)
151156
end
157+
end
152158

153159
@testset "constant" begin
154160
for x in [[1, 2, 3], [1.1, 2.2, 3.3], [1.1 + 2.2im, 3.3 + 4.4im, 5.5 + 6.6im]]

0 commit comments

Comments
 (0)