Skip to content

Commit 2767797

Browse files
committed
Merge branch 'joshs-working-branch' into develop
2 parents e044edb + 20c0ec7 commit 2767797

28 files changed

+443
-118
lines changed

dsa2000_cal/benchmarking/calibration/benchmark_JtJg_calculation_BTC_cpu.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import os
22
from functools import partial
33

4+
from dsa2000_common.common.fit_benchmark import fit_timings
5+
46
os.environ['JAX_PLATFORMS'] = 'cpu'
57
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '1.0'
68
os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={8}"
@@ -152,13 +154,11 @@ def main():
152154
# dt = (t1 - t0) / 1
153155
# dsa_logger.info(f"BTC: Subtract (all-GPU sharded): CPU D={D}: {dt}")
154156

155-
# Fit line to data using scipy
156157
time_array = np.array(time_array)
157158
d_array = np.array(d_array)
158-
from scipy.optimize import curve_fit
159159

160-
popt, pcov = curve_fit(lambda x, a, b: a * x ** b, d_array, time_array)
161-
dsa_logger.info(f"BTC: Fit: {popt}")
160+
a, b, c = fit_timings(d_array, time_array)
161+
dsa_logger.info(f"Fit: t(n) = {a:.4f} * n ** {b:.2f} + {c:.4f}")
162162

163163

164164
if __name__ == '__main__':

dsa2000_cal/benchmarking/calibration/benchmark_JtJg_calculation_TBC_cpu.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import os
22
from functools import partial
33

4+
from dsa2000_common.common.fit_benchmark import fit_timings
5+
46
os.environ['JAX_PLATFORMS'] = 'cpu'
57
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '1.0'
68
os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={8}"
@@ -154,18 +156,14 @@ def main():
154156
# dt = (t1 - t0) / 1
155157
# dsa_logger.info(f"TBC: Subtract (all-GPU sharded): CPU D={D}: {dt}")
156158

157-
# Fit line to data using scipy
158159
time_array = np.array(time_array)
159-
d_array = np.array(d_array)
160-
from scipy.optimize import curve_fit
161-
162-
popt, pcov = curve_fit(lambda x, a, b: a * x ** b, d_array, time_array)
163-
dsa_logger.info(f"TBC: Fit: {popt}")
164-
165160
shard_time_array = np.array(shard_time_array)
161+
d_array = np.array(d_array)
166162

167-
popt, pcov = curve_fit(lambda x, a, b: a * x ** b, d_array, shard_time_array)
168-
dsa_logger.info(f"TBC: Fit (sharded): {popt}")
163+
a, b, c = fit_timings(d_array, time_array)
164+
dsa_logger.info(f"Fit: t(n) = {a:.4f} * n ** {b:.2f} + {c:.4f}")
165+
a, b, c = fit_timings(d_array, shard_time_array)
166+
dsa_logger.info(f"Fit (sharded): t(n) = {a:.4f} * n ** {b:.2f} + {c:.4f}")
169167

170168

171169
if __name__ == '__main__':

dsa2000_cal/benchmarking/calibration/benchmark_JtJg_calculation_TBC_gpu.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import os
22
from functools import partial
33

4+
from dsa2000_common.common.fit_benchmark import fit_timings
5+
46
os.environ['JAX_PLATFORMS'] = 'cuda'
57
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '1.0'
68

@@ -101,7 +103,7 @@ def main():
101103
t1 = time.time()
102104
dt = (t1 - t0) / 10
103105
dsa_logger.info(f"TBC: J^T.J.g (Full avg.): CPU D={D}: {dt}")
104-
d_array.append(dt)
106+
d_array.append(D)
105107
shard_time_array.append(dt)
106108

107109
data = prepare_data(D, Ts=4, Tm=1, Cs=4, Cm=1)
@@ -123,14 +125,11 @@ def main():
123125
# dt = (t1 - t0) / 1
124126
# dsa_logger.info(f"TBC: Subtract (all-GPU sharded): CPU D={D}: {dt}")
125127

126-
# Fit line to data using scipy
127-
d_array = np.array(d_array)
128-
from scipy.optimize import curve_fit
129-
130128
shard_time_array = np.array(shard_time_array)
129+
d_array = np.array(d_array)
131130

132-
popt, pcov = curve_fit(lambda x, a, b: a * x ** b, d_array, shard_time_array,bounds=([0., 0.5], [np.inf, 1.5]))
133-
dsa_logger.info(f"TBC: Fit (sharded): {popt}")
131+
a, b, c = fit_timings(d_array, shard_time_array)
132+
dsa_logger.info(f"Fit: t(n) = {a:.4f} * n ** {b:.2f} + {c:.4f}")
134133

135134

136135
if __name__ == '__main__':

dsa2000_cal/benchmarking/calibration/benchmark_JtR_calculation_BTC_cpu.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import os
22
from functools import partial
33

4+
from dsa2000_common.common.fit_benchmark import fit_timings
5+
46
os.environ['JAX_PLATFORMS'] = 'cpu'
57
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '1.0'
68
os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={8}"
@@ -152,13 +154,11 @@ def main():
152154
# dt = (t1 - t0) / 1
153155
# dsa_logger.info(f"BTC: Subtract (all-GPU sharded): CPU D={D}: {dt}")
154156

155-
# Fit line to data using scipy
156157
time_array = np.array(time_array)
157158
d_array = np.array(d_array)
158-
from scipy.optimize import curve_fit
159159

160-
popt, pcov = curve_fit(lambda x, a, b: a * x ** b, d_array, time_array)
161-
dsa_logger.info(f"BTC: Fit: {popt}")
160+
a, b, c = fit_timings(d_array, time_array)
161+
dsa_logger.info(f"Fit: t(n) = {a:.4f} * n ** {b:.2f} + {c:.4f}")
162162

163163

164164
if __name__ == '__main__':

dsa2000_cal/benchmarking/calibration/benchmark_JtR_calculation_TBC_cpu.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import os
22
from functools import partial
33

4+
from dsa2000_common.common.fit_benchmark import fit_timings
5+
46
os.environ['JAX_PLATFORMS'] = 'cpu'
57
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '1.0'
68
os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={8}"
@@ -153,18 +155,14 @@ def main():
153155
# dt = (t1 - t0) / 1
154156
# dsa_logger.info(f"TBC: Subtract (all-GPU sharded): CPU D={D}: {dt}")
155157

156-
# Fit line to data using scipy
157158
time_array = np.array(time_array)
158-
d_array = np.array(d_array)
159-
from scipy.optimize import curve_fit
160-
161-
popt, pcov = curve_fit(lambda x, a, b: a * x ** b, d_array, time_array)
162-
dsa_logger.info(f"TBC: Fit: {popt}")
163-
164159
shard_time_array = np.array(shard_time_array)
160+
d_array = np.array(d_array)
165161

166-
popt, pcov = curve_fit(lambda x, a, b: a * x ** b, d_array, shard_time_array)
167-
dsa_logger.info(f"TBC: Fit (sharded): {popt}")
162+
a, b, c = fit_timings(d_array, time_array)
163+
dsa_logger.info(f"Fit: t(n) = {a:.4f} * n ** {b:.2f} + {c:.4f}")
164+
a, b, c = fit_timings(d_array, shard_time_array)
165+
dsa_logger.info(f"Fit (sharded): t(n) = {a:.4f} * n ** {b:.2f} + {c:.4f}")
168166

169167

170168
if __name__ == '__main__':

dsa2000_cal/benchmarking/calibration/benchmark_JtR_calculation_TBC_gpu.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import os
22
from functools import partial
33

4+
from dsa2000_common.common.fit_benchmark import fit_timings
5+
46
os.environ['JAX_PLATFORMS'] = 'cuda'
57
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '1.0'
68

@@ -100,7 +102,7 @@ def main():
100102
t1 = time.time()
101103
dt = (t1 - t0) / 10
102104
dsa_logger.info(f"TBC: J^T.R (Full Avg.): CPU D={D}: {dt}")
103-
d_array.append(dt)
105+
d_array.append(D)
104106
shard_time_array.append(dt)
105107

106108
data = prepare_data(D, Ts=4, Tm=1, Cs=4, Cm=1)
@@ -121,15 +123,11 @@ def main():
121123
# dt = (t1 - t0) / 1
122124
# dsa_logger.info(f"TBC: Subtract (all-GPU sharded): CPU D={D}: {dt}")
123125

124-
# Fit line to data using scipy
125-
d_array = np.array(d_array)
126-
from scipy.optimize import curve_fit
127-
128126
shard_time_array = np.array(shard_time_array)
127+
d_array = np.array(d_array)
129128

130-
popt, pcov = curve_fit(lambda x, a, b: a * x ** b, d_array, shard_time_array, bounds=([0., 0.5], [np.inf, 1.5]))
131-
dsa_logger.info(f"TBC: Fit (sharded): {popt}")
132-
129+
a, b, c = fit_timings(d_array, shard_time_array)
130+
dsa_logger.info(f"Fit: t(n) = {a:.4f} * n ** {b:.2f} + {c:.4f}")
133131

134132
if __name__ == '__main__':
135133
main()

dsa2000_cal/benchmarking/calibration/benchmark_R_calculation_BTC_cpu.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import os
22
from functools import partial
33

4+
from dsa2000_common.common.fit_benchmark import fit_timings
5+
46
os.environ['JAX_PLATFORMS'] = 'cpu'
57
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '1.0'
68
os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={8}"
@@ -140,11 +142,11 @@ def main():
140142
# Fit line to data using scipy
141143
time_array = np.array(time_array)
142144
d_array = np.array(d_array)
143-
from scipy.optimize import curve_fit
144-
145-
popt, pcov = curve_fit(lambda x, a, b: a * x ** b, d_array, time_array)
146-
dsa_logger.info(f"BTC: Fit: {popt}")
145+
time_array = np.array(time_array)
146+
d_array = np.array(d_array)
147147

148+
a, b, c = fit_timings(d_array, time_array)
149+
dsa_logger.info(f"Fit: t(n) = {a:.4f} * n ** {b:.2f} + {c:.4f}")
148150

149151
if __name__ == '__main__':
150152
main()

dsa2000_cal/benchmarking/calibration/benchmark_R_calculation_TBC_cpu.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import os
22
from functools import partial
33

4+
from dsa2000_common.common.fit_benchmark import fit_timings
5+
46
os.environ['JAX_PLATFORMS'] = 'cpu'
57
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '1.0'
68
os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={8}"
@@ -137,18 +139,14 @@ def main():
137139
# dt = (t1 - t0) / 1
138140
# dsa_logger.info(f"TBC: Subtract (all-GPU sharded): CPU D={D}: {dt}")
139141

140-
# Fit line to data using scipy
141142
time_array = np.array(time_array)
142-
d_array = np.array(d_array)
143-
from scipy.optimize import curve_fit
144-
145-
popt, pcov = curve_fit(lambda x, a, b: a * x ** b, d_array, time_array)
146-
dsa_logger.info(f"TBC: Fit: {popt}")
147-
148143
shard_time_array = np.array(shard_time_array)
144+
d_array = np.array(d_array)
149145

150-
popt, pcov = curve_fit(lambda x, a, b: a * x ** b, d_array, shard_time_array)
151-
dsa_logger.info(f"TBC: Fit (sharded): {popt}")
146+
a, b, c = fit_timings(d_array, time_array)
147+
dsa_logger.info(f"Fit: t(n) = {a:.4f} * n ** {b:.2f} + {c:.4f}")
148+
a, b, c = fit_timings(d_array, shard_time_array)
149+
dsa_logger.info(f"Fit (sharded): t(n) = {a:.4f} * n ** {b:.2f} + {c:.4f}")
152150

153151

154152
if __name__ == '__main__':

dsa2000_cal/benchmarking/calibration/benchmark_R_calculation_TBC_gpu.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import os
22
from functools import partial
33

4+
from dsa2000_common.common.fit_benchmark import fit_timings
5+
46
os.environ['JAX_PLATFORMS'] = 'cuda'
57
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '1.0'
68
from jax._src.partition_spec import PartitionSpec
@@ -105,14 +107,12 @@ def main():
105107
# dt = (t1 - t0) / 1
106108
# dsa_logger.info(f"TBC: Subtract (all-GPU sharded): CPU D={D}: {dt}")
107109

108-
# Fit line to data using scipy
109-
from scipy.optimize import curve_fit
110-
111110
shard_time_array = np.array(shard_time_array)
112111
d_array = np.array(d_array)
113112

114-
popt, pcov = curve_fit(lambda x, a, b: a * x ** b, d_array, shard_time_array, bounds=([0., 0.5], [np.inf, 1.5]))
115-
dsa_logger.info(f"TBC: Fit (sharded): {popt}")
113+
a, b, c = fit_timings(d_array, shard_time_array)
114+
dsa_logger.info(f"Fit: t(n) = {a:.4f} * n ** {b:.2f} + {c:.4f}")
115+
116116

117117

118118
if __name__ == '__main__':

dsa2000_cal/benchmarking/calibration/benchmark_calibration_gpu.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import os
22

3+
from dsa2000_common.common.fit_benchmark import fit_timings
4+
35
os.environ['JAX_PLATFORMS'] = 'cuda'
46
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '1.0'
57

@@ -112,13 +114,11 @@ def main():
112114
# dt = (t1 - t0) / 3
113115
# dsa_logger.info(f"Calibration Single-Iteration Single-CG Step (C=40 w/ reps): CPU D={D}: {dt}")
114116

115-
# Fit line to data using scipy
116117
time_array = np.array(time_array)
117118
d_array = np.array(d_array)
118-
from scipy.optimize import curve_fit
119119

120-
popt, pcov = curve_fit(lambda x, a, b: a * x ** b, d_array, time_array, bounds=([0., 0.5], [np.inf, 1.5]))
121-
dsa_logger.info(f"Fit: {popt}")
120+
a, b, c = fit_timings(d_array, time_array)
121+
dsa_logger.info(f"Fit: t(n) = {a:.4f} * n ** {b:.2f} + {c:.4f}")
122122

123123

124124
if __name__ == '__main__':

0 commit comments

Comments
 (0)