Skip to content

Commit a7b97db

Browse files
committed
removing CPU counterpart from rescale to int and updating tests
1 parent 67a3c83 commit a7b97db

File tree

2 files changed

+25
-72
lines changed

2 files changed

+25
-72
lines changed

httomolibgpu/misc/rescale.py

Lines changed: 25 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -30,28 +30,27 @@
3030

3131
from httomolibgpu.misc.supp_func import data_checker
3232

33-
from sofia.core.regularisers import rescale_to_int_C
3433

3534
__all__ = [
3635
"rescale_to_int",
3736
]
3837

3938

4039
def rescale_to_int(
41-
data: Union[np.ndarray, cp.ndarray],
40+
data: cp.ndarray,
4241
perc_range_min: float = 0.0,
4342
perc_range_max: float = 100.0,
4443
bits: Literal[8, 16, 32] = 8,
4544
glob_stats: Optional[Tuple[float, float, float, int]] = None,
46-
) -> Union[np.ndarray, cp.ndarray]:
45+
) -> cp.ndarray:
4746
"""
4847
Rescales the data given as float32 type and converts it into the range of an unsigned integer type
4948
with the given number of bits. For more detailed information and examples, see :ref:`method_rescale_to_int`.
5049
5150
Parameters
5251
----------
53-
data : Union[np.ndarray, cp.ndarray]
54-
Input data as a numpy or cupy array (the function is cpu-gpu agnostic)
52+
data : cp.ndarray
53+
Input data as a cupy array
5554
perc_range_min: float, optional
5655
The lower cutoff point in the input data, in percent of the data range (defaults to 0).
5756
The lower bound is computed as min + perc_range_min/100*(max-min)
@@ -71,7 +70,7 @@ def rescale_to_int(
7170
7271
Returns
7372
-------
74-
Union[np.ndarray, cp.ndarray]
73+
cp.ndarray
7574
The original data, clipped to the range specified with the perc_range_min and
7675
perc_range_max, and scaled to the full range of the output integer type
7776
"""
@@ -84,18 +83,13 @@ def rescale_to_int(
8483

8584
data = data_checker(data, verbosity=True, method_name="rescale_to_int")
8685

87-
if cupy_run:
88-
xp = cp.get_array_module(data)
89-
else:
90-
import numpy as xp
91-
9286
# get the min and max integer values of the output type
93-
output_min = xp.iinfo(output_dtype).min
94-
output_max = xp.iinfo(output_dtype).max
87+
output_min = cp.iinfo(output_dtype).min
88+
output_max = cp.iinfo(output_dtype).max
9589

9690
if not isinstance(glob_stats, tuple):
97-
min_value = float(xp.min(data))
98-
max_value = float(xp.max(data))
91+
min_value = float(cp.min(data))
92+
max_value = float(cp.max(data))
9993
else:
10094
min_value = glob_stats[0]
10195
max_value = glob_stats[1]
@@ -104,33 +98,21 @@ def rescale_to_int(
10498
input_min = (perc_range_min * (range_intensity) / 100) + min_value
10599
input_max = (perc_range_max * (range_intensity) / 100) + min_value
106100

101+
factor = cp.float32(1.0)
107102
if (input_max - input_min) != 0.0:
108-
factor = xp.float32((output_max - output_min) / (input_max - input_min))
109-
else:
110-
factor = 1.0
111-
112-
if xp.__name__ == "numpy":
113-
if input_max == pow(2, 32):
114-
input_max -= 1
115-
res = rescale_to_int_C(data, input_min, input_max, factor)
116-
# res = np.copy(data.astype(float))
117-
# res[data.astype(float) < input_min] = int(input_min)
118-
# res[data.astype(float) > input_max] = int(input_max)
119-
# res -= input_min
120-
# res *= factor
121-
# res = output_dtype(res)
122-
else:
123-
res = xp.empty(data.shape, dtype=output_dtype)
124-
rescale_kernel = cp.ElementwiseKernel(
125-
"T x, raw T input_min, raw T input_max, raw T factor",
126-
"O out",
127-
"""
128-
T x_clean = isnan(x) || isinf(x) ? T(0) : x;
129-
T x_clipped = x_clean < input_min ? input_min : (x_clean > input_max ? input_max : x_clean);
130-
T x_rebased = x_clipped - input_min;
131-
out = O(x_rebased * factor);
132-
""",
133-
"rescale_to_int",
134-
)
135-
rescale_kernel(data, input_min, input_max, factor, res)
103+
factor = cp.float32((output_max - output_min) / (input_max - input_min))
104+
105+
res = cp.empty(data.shape, dtype=output_dtype)
106+
rescale_kernel = cp.ElementwiseKernel(
107+
"T x, raw T input_min, raw T input_max, raw T factor",
108+
"O out",
109+
"""
110+
T x_clean = isnan(x) || isinf(x) ? T(0) : x;
111+
T x_clipped = x_clean < input_min ? input_min : (x_clean > input_max ? input_max : x_clean);
112+
T x_rebased = x_clipped - input_min;
113+
out = O(x_rebased * factor);
114+
""",
115+
"rescale_to_int",
116+
)
117+
rescale_kernel(data, input_min, input_max, factor, res)
136118
return res

tests/test_misc/test_rescale.py

Lines changed: 0 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,11 @@ def test_rescale_no_change():
1414
res_dev = rescale_to_int(
1515
data_dev, bits=8, glob_stats=(0.0, 255.0, 100.0, data.size)
1616
)
17-
res_cpu = rescale_to_int(data, bits=8, glob_stats=(0.0, 255.0, 100.0, data.size))
1817

1918
res = cp.asnumpy(res_dev).astype(np.float32)
2019

2120
assert res_dev.dtype == np.uint8
22-
assert res_cpu.dtype == np.uint8
2321
np.testing.assert_array_equal(res, data)
24-
np.testing.assert_array_equal(res, res_cpu)
2522

2623

2724
@pytest.mark.parametrize("bits", [8, 16, 32])
@@ -31,28 +28,22 @@ def test_rescale_no_change_no_stats(bits: Literal[8, 16, 32]):
3128
data[13, 1] = (2**bits) - 1
3229
data_dev = cp.asarray(data)
3330
res_dev = rescale_to_int(data_dev, bits=bits)
34-
res_cpu = rescale_to_int(data, bits=bits)
3531

3632
res_dev_float32 = cp.asnumpy(res_dev).astype(np.float32)
3733

3834
assert res_dev.dtype.itemsize == bits // 8
3935
np.testing.assert_array_equal(res_dev_float32, data)
40-
assert res_cpu.dtype.itemsize == bits // 8
41-
res_cpu_float32 = np.float32(res_cpu)
42-
np.testing.assert_array_equal(res_dev_float32, res_cpu_float32)
4336

4437

4538
def test_rescale_double():
4639
data = np.ones((30, 50), dtype=np.float32)
4740

4841
data_dev = cp.asarray(data)
4942
res_dev = rescale_to_int(data_dev, bits=8, glob_stats=(0, 2, 100, data.size))
50-
res_cpu = rescale_to_int(data, bits=8, glob_stats=(0, 2, 100, data.size))
5143

5244
res = cp.asnumpy(res_dev).astype(np.float32)
5345

5446
np.testing.assert_array_almost_equal(res, 127.0)
55-
np.testing.assert_array_almost_equal(res_cpu, 127.0)
5647

5748

5849
def test_rescale_handles_nan_inf():
@@ -63,25 +54,21 @@ def test_rescale_handles_nan_inf():
6354

6455
data_dev = cp.asarray(data)
6556
res_dev = rescale_to_int(data_dev, bits=8, glob_stats=(0, 2, 100, data.size))
66-
res_cpu = rescale_to_int(data, bits=8, glob_stats=(0, 2, 100, data.size))
6757

6858
res = cp.asnumpy(res_dev).astype(np.float32)
6959

7060
np.testing.assert_array_equal(res[0, 0:3], 0.0)
71-
np.testing.assert_array_equal(res_cpu[0, 0:3], 0.0)
7261

7362

7463
def test_rescale_double_offset():
7564
data = np.ones((30, 50), dtype=np.float32) + 10
7665

7766
data_dev = cp.asarray(data)
7867
res_dev = rescale_to_int(data_dev, bits=8, glob_stats=(10, 12, 100, data.size))
79-
res_cpu = rescale_to_int(data, bits=8, glob_stats=(10, 12, 100, data.size))
8068

8169
res = cp.asnumpy(res_dev).astype(np.float32)
8270

8371
np.testing.assert_array_almost_equal(res, 127.0)
84-
np.testing.assert_array_almost_equal(res_cpu, 127.0)
8572

8673

8774
@pytest.mark.parametrize("bits", [8, 16])
@@ -99,14 +86,6 @@ def test_rescale_double_offset_min_percentage(bits: Literal[8, 16, 32]):
9986
perc_range_max=90.0,
10087
)
10188

102-
res_cpu = rescale_to_int(
103-
data,
104-
bits=bits,
105-
glob_stats=(10, 20, 100, data.size),
106-
perc_range_min=10.0,
107-
perc_range_max=90.0,
108-
)
109-
11089
res = cp.asnumpy(res_dev).astype(np.float32)
11190

11291
max = (2**bits) - 1
@@ -116,22 +95,14 @@ def test_rescale_double_offset_min_percentage(bits: Literal[8, 16, 32]):
11695
assert res[0, 0] == 0.0
11796
assert res[0, 1] == max
11897

119-
res_cpu = res_cpu.astype(np.float32)
120-
np.testing.assert_array_almost_equal(res_cpu[1:, :], num)
121-
assert res_cpu[0, 0] == 0.0
122-
assert res_cpu[0, 1] == max
123-
12498

12599
def test_tomo_data_scale(data):
126100
data_cpu = data.get()
127101
res_dev = rescale_to_int(
128102
data.astype(cp.float32), perc_range_min=10, perc_range_max=90, bits=8
129103
)
130-
res_cpu = rescale_to_int(data_cpu, perc_range_min=10, perc_range_max=90, bits=8)
131104
res = res_dev.get()
132105
assert res_dev.dtype == np.uint8
133-
assert res_dev.dtype == np.uint8
134-
np.testing.assert_array_equal(res_cpu, res)
135106

136107

137108
@pytest.mark.perf

0 commit comments

Comments
 (0)