Skip to content

Commit c59aafa

Browse files
authored
[fix] match vbeam parameters (#22)
* Changes that seem to help with tolerance? double-check * Add single-transmit test? * Was missing a tukey-alpha=0 that was causing an issue * FIx grid shape typo * Fix the atol rtol issues * Plot the difference map in db instead of fraction * Fix lint * Some cleanup * Some cleanup * Clean up some prints * Minor fix * More cleanup --------- Co-authored-by: Charles Guan <3221512+charlesincharge@users.noreply.github.com>
1 parent ecb8244 commit c59aafa

File tree

5 files changed

+116
-32
lines changed

5 files changed

+116
-32
lines changed

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,9 @@ preview = true
186186
filterwarnings = [
187187
"error", # make warnings errors
188188
"ignore:.*This will add latency due to CPU<->GPU memory transfers.*:UserWarning",
189+
# vbeam warnings
190+
"ignore:point_position will be overwritten by the scan.:UserWarning",
191+
"ignore:Both point_position and scan are set. Scan will be used.:UserWarning",
189192
]
190193
markers = [
191194
# see conftest.py for default-addition of cuda marker

src/mach/_vis.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
"""Visualization utilities for test-diagnostics."""
22

3-
from copy import copy
43
from pathlib import Path
54
from typing import Optional
65

@@ -95,7 +94,7 @@ def save_debug_figures(
9594
if reference_result is not None:
9695
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
9796

98-
max_value = max(np.max(np.abs(our_img)), np.max(np.abs(ref_img)))
97+
max_value = max(np.max(np.abs(our_img)), np.max(np.abs(ref_img)), 1e-12)
9998

10099
# Our result - convert to dB
101100
our_img_db = db(our_img / max_value)
@@ -131,15 +130,14 @@ def save_debug_figures(
131130
cbar3 = plt.colorbar(im3, ax=axes[1, 0])
132131
cbar3.set_label("Linear")
133132

134-
# Relative difference
135-
rel_diff = np.abs(diff_img) / (np.abs(ref_img) + 1e-10)
136-
cmap = copy(plt.cm.hot)
137-
cmap.set_over("blue", 1.0)
138-
im4 = axes[1, 1].imshow(rel_diff.T, aspect="auto", origin="upper", cmap=cmap, extent=extent, vmin=0, vmax=1)
139-
axes[1, 1].set_title("Relative Difference")
133+
# Relative difference in dB
134+
diff_db = db(diff_img / max_value)
135+
im4 = axes[1, 1].imshow(diff_db.T, aspect="auto", origin="upper", cmap="hot", extent=extent, vmin=-140, vmax=0)
136+
axes[1, 1].set_title("Difference (dB, 0dB = max(ref, our))")
140137
axes[1, 1].set_xlabel("Lateral [cm]")
141138
axes[1, 1].set_ylabel("Depth [cm]")
142-
plt.colorbar(im4, ax=axes[1, 1], extend="max")
139+
cbar4 = plt.colorbar(im4, ax=axes[1, 1])
140+
cbar4.set_label("dB")
143141

144142
else:
145143
fig, ax = plt.subplots(1, 1, figsize=(8, 6))

src/mach/experimental.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,6 @@ def beamform(
7979
for transmit_idx in range(n_transmits):
8080
# Extract single-transmit data
8181
single_channel_data = channel_data[transmit_idx]
82-
single_rx_start_s = rx_start_s[transmit_idx]
8382

8483
# Call single-transmit beamform
8584
_ = kernel.beamform(
@@ -88,7 +87,7 @@ def beamform(
8887
scan_coords_m,
8988
tx_wave_arrivals_s[transmit_idx],
9089
out=out,
91-
rx_start_s=single_rx_start_s,
90+
rx_start_s=rx_start_s,
9291
sampling_freq_hz=sampling_freq_hz,
9392
f_number=f_number,
9493
sound_speed_m_s=sound_speed_m_s,

src/mach/io/uff.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ def compute_tx_wave_arrivals_s(
4646
scan_coords_m: Output positions for beamforming (N, 3) in meters
4747
speed_of_sound: Speed of sound in the medium
4848
origin: Wave origin point, defaults to [0, 0, 0]
49+
Can also be parsed from ultrasound_angles_to_cartesian(channel_data.sequence[idx].origin)
4950
xp: Array namespace (optional, will be inferred from scan_coords_m)
5051
5152
Returns:
@@ -159,9 +160,12 @@ def create_beamforming_setup(channel_data, scan, f_number: float = 1.7, xp=None)
159160
# Compute transmit arrivals for all transmits
160161
directions = extract_wave_directions(channel_data.sequence, xp or np)
161162
tx_wave_arrivals_s = compute_tx_wave_arrivals_s(directions, scan_coords_m, speed_of_sound, xp=xp)
163+
# further delay each transmit by the delay of the wave
164+
tx_wave_arrivals_s = tx_wave_arrivals_s + extract_sequence_delays(channel_data.sequence, xp)[:, None]
162165

163-
# Extract timing information for all transmits
164-
rx_delays = extract_sequence_delays(channel_data.sequence, xp)
166+
# Account for initial_time offset (this is how vbeam handles it)
167+
# The initial_time represents when the first sample was acquired relative to t=0
168+
rx_start_s = float(channel_data.initial_time)
165169

166170
return {
167171
"channel_data": signal,
@@ -173,7 +177,7 @@ def create_beamforming_setup(channel_data, scan, f_number: float = 1.7, xp=None)
173177
"sampling_freq_hz": sampling_freq_hz,
174178
"sound_speed_m_s": speed_of_sound,
175179
"modulation_freq_hz": modulation_frequency,
176-
"rx_start_s": rx_delays,
180+
"rx_start_s": rx_start_s,
177181
}
178182

179183

@@ -198,8 +202,7 @@ def create_single_transmit_beamforming_setup(
198202

199203
# Extract single transmit data
200204
single_setup = multi_setup.copy()
201-
single_setup["channel_data"] = multi_setup["channel_data"][wave_index] # Remove transmit dimension
202-
single_setup["tx_wave_arrivals_s"] = multi_setup["tx_wave_arrivals_s"][wave_index] # Remove transmit dimension
203-
single_setup["rx_start_s"] = multi_setup["rx_start_s"][wave_index]
205+
single_setup["channel_data"] = multi_setup["channel_data"][wave_index]
206+
single_setup["tx_wave_arrivals_s"] = multi_setup["tx_wave_arrivals_s"][wave_index]
204207

205208
return single_setup

tests/compare/test_vbeam.py

Lines changed: 96 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,9 @@
3737
from vbeam.scan import LinearScan
3838
from vbeam.wavefront import PlaneWavefront, ReflectedWavefront
3939

40-
from mach import experimental
40+
from mach import experimental, kernel
4141
from mach._vis import save_debug_figures
42-
from mach.io.uff import create_beamforming_setup
42+
from mach.io.uff import create_beamforming_setup, create_single_transmit_beamforming_setup
4343

4444
# ============================================================================
4545
# Fixtures
@@ -178,14 +178,44 @@ def vbeam_setup_uff(
178178

179179

180180
@pytest.fixture(scope="module")
181-
def picmus_phantom_resolution_beamform_kwargs(
181+
def mach_beamform_kwargs(
182182
picmus_phantom_resolution_channel_data: ChannelData, picmus_phantom_resolution_scan: Scan
183183
) -> dict:
184184
"""mach kwargs for UFF data."""
185185
return create_beamforming_setup(
186186
channel_data=picmus_phantom_resolution_channel_data,
187187
scan=picmus_phantom_resolution_scan,
188-
xp=cp if HAS_CUPY else None,
188+
xp=cp if HAS_CUPY else np,
189+
)
190+
191+
192+
# ============================================================================
193+
# Single Transmit Test Fixtures
194+
# ============================================================================
195+
196+
197+
@pytest.fixture(scope="module", params=[0, 1, 10, 37, 74])
198+
def transmit_idx(request):
199+
"""Parametrized transmit index for single-transmit testing."""
200+
return request.param
201+
202+
203+
@pytest.fixture(scope="module")
204+
def vbeam_setup_uff_single_transmit(vbeam_setup_uff: SignalForPointSetup, transmit_idx: int) -> SignalForPointSetup:
205+
"""Create a single-transmit vbeam setup from the full UFF setup."""
206+
return vbeam_setup_uff.slice["transmits", transmit_idx]
207+
208+
209+
@pytest.fixture(scope="module")
210+
def mach_single_transmit_kwargs(
211+
picmus_phantom_resolution_channel_data: ChannelData, picmus_phantom_resolution_scan: Scan, transmit_idx: int
212+
) -> dict:
213+
"""mach kwargs for UFF data."""
214+
return create_single_transmit_beamforming_setup(
215+
channel_data=picmus_phantom_resolution_channel_data,
216+
scan=picmus_phantom_resolution_scan,
217+
wave_index=transmit_idx,
218+
xp=cp if HAS_CUPY else np,
189219
)
190220

191221

@@ -261,13 +291,66 @@ def vbeam_beamform():
261291

262292
@pytest.mark.skipif(not HAS_CUPY, reason="CuPy not available")
263293
@pytest.mark.filterwarnings("ignore:array is not contiguous, rearranging will add latency:UserWarning")
264-
def test_mach_matches_vbeam(
265-
picmus_phantom_resolution_beamform_kwargs, vbeam_setup_uff: SignalForPointSetup, output_dir
294+
def test_mach_matches_vbeam_single_transmit(
295+
mach_single_transmit_kwargs: dict,
296+
vbeam_setup_uff_single_transmit: SignalForPointSetup,
297+
transmit_idx: int,
298+
output_dir,
266299
):
300+
"""Test mach vs vbeam on a single plane wave transmit to isolate core beamforming differences."""
301+
grid_shape = vbeam_setup_uff_single_transmit.scan.shape
302+
303+
# Run mach single-transmit beamforming using kernel.beamform directly
304+
gpu_result = kernel.beamform(**mach_single_transmit_kwargs, tukey_alpha=0.0)
305+
result = cp.asnumpy(gpu_result)
306+
# Reshape to (x, z)
307+
result = result.reshape(grid_shape)
308+
309+
# Verify basic properties
310+
assert np.isfinite(result).all()
311+
312+
# Run vbeam single-transmit beamforming
313+
beamformer = get_das_beamformer(
314+
vbeam_setup_uff_single_transmit,
315+
compensate_for_apodization_overlap=False,
316+
log_compress=False,
317+
scan_convert=False,
318+
)
319+
vbeam_result_jax = beamformer(**vbeam_setup_uff_single_transmit.data).block_until_ready()
320+
vbeam_result = np.asarray(vbeam_result_jax)
321+
322+
# Save debug output if requested
323+
if output_dir is not None:
324+
output_dir = output_dir / "single_transmit_comparison" / f"transmit_{transmit_idx}"
325+
save_debug_figures(
326+
our_result=np.abs(result),
327+
reference_result=np.abs(vbeam_result),
328+
grid_shape=grid_shape,
329+
x_axis=vbeam_setup_uff_single_transmit.scan.x,
330+
z_axis=vbeam_setup_uff_single_transmit.scan.z,
331+
output_dir=output_dir,
332+
test_name=f"single_transmit_{transmit_idx}",
333+
our_label="mach",
334+
reference_label="vbeam",
335+
)
336+
337+
np.testing.assert_allclose(
338+
actual=result,
339+
desired=vbeam_result,
340+
atol=0.01,
341+
rtol=1 / 100,
342+
err_msg=f"mach single transmit {transmit_idx} results do not match vbeam within expected tolerances",
343+
)
344+
345+
346+
@pytest.mark.skipif(not HAS_CUPY, reason="CuPy not available")
347+
@pytest.mark.filterwarnings("ignore:array is not contiguous, rearranging will add latency:UserWarning")
348+
def test_mach_matches_vbeam(mach_beamform_kwargs, vbeam_setup_uff: SignalForPointSetup, output_dir):
267349
"""Validate mach against vbeam output on a PICMUS UFF data file."""
268350
grid_shape = vbeam_setup_uff.scan.shape
269351

270-
gpu_result = experimental.beamform(**picmus_phantom_resolution_beamform_kwargs)
352+
# Match our custom vbeam apodization settings
353+
gpu_result = experimental.beamform(**mach_beamform_kwargs, tukey_alpha=0.0)
271354
result = cp.asnumpy(gpu_result)
272355
# Reshape to (x, z)
273356
result = result.reshape(grid_shape)
@@ -286,7 +369,7 @@ def test_mach_matches_vbeam(
286369
vbeam_result_jax = beamformer(**vbeam_setup_uff.data).block_until_ready()
287370
vbeam_result = np.asarray(vbeam_result_jax)
288371

289-
# Compare magnitudes because we handle the phase slightly differently from vbeam
372+
# Also show magnitude comparison for reference
290373
vbeam_magnitude = np.abs(vbeam_result)
291374
cuda_magnitude = np.abs(result)
292375

@@ -306,14 +389,12 @@ def test_mach_matches_vbeam(
306389
)
307390
print("Saved debug figures to", output_dir)
308391

309-
# Validate mach against vbeam
310-
# TODO: may want to further fine-tune these tolerances
311392
np.testing.assert_allclose(
312-
actual=cuda_magnitude,
313-
desired=vbeam_magnitude,
314-
atol=10,
315-
rtol=0.3,
316-
err_msg="mach results do not match vbeam within expected tolerances",
393+
actual=result,
394+
desired=vbeam_result,
395+
atol=0.01,
396+
rtol=1 / 100,
397+
err_msg="mach complex results do not match vbeam within expected tolerances (with scaling correction)",
317398
)
318399

319400

0 commit comments

Comments
 (0)