Skip to content

Commit bbc0889

Browse files
Use normalisation from paper and add PyTorch notebook example (#9)
- The convergence is very dependent on the data scaling and the corresponding hyperparameters, so we are again using the (-2, 2) data scaling for the branch net. - Created a notebook that should give a good introduction to DeepONets and, in particular, to the implementation of the wave propagation. Plotting methods are provided for 2D data.
1 parent 11ff54d commit bbc0889

File tree

76 files changed

+1897
-51
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

76 files changed

+1897
-51
lines changed

.vscode/launch.json

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,26 +12,26 @@
1212
"console": "integratedTerminal"
1313
},
1414
{
15-
"name": "Python: Train 3D",
15+
"name": "Python: Train",
1616
"type": "debugpy",
1717
"request": "launch",
1818
"module": "deeponet_acoustics.main_train",
1919
"args": [
2020
"--path_settings",
21-
"json_setups/threeD/settings.json",
21+
"json_setups/threeD/choras_cube.json",
2222
],
2323
"console": "integratedTerminal"
24-
},
24+
},
2525
{
26-
"name": "Python: Inference 3D",
26+
"name": "Python: Inference",
2727
"type": "debugpy",
2828
"request": "launch",
2929
"module": "deeponet_acoustics.main_inference",
3030
"args": [
3131
"--path_settings",
32-
"json_setups/threeD/settings.json",
32+
"json_setups/twoD/rect2x2.json",
3333
"--path_eval_settings",
34-
"json_setups/threeD/settings_eval.json",
34+
"json_setups/twoD/rect2x2_eval.json",
3535
],
3636
"console": "integratedTerminal"
3737
},

deeponet_acoustics/datahandlers/datagenerators.py

Lines changed: 72 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import h5py
1717
import jax.numpy as jnp
1818
import numpy as np
19+
import torch
1920
from torch.utils.data import Dataset
2021

2122
import deeponet_acoustics.datahandlers.io as IO
@@ -160,6 +161,7 @@ def __init__(
160161
data_prune=1,
161162
norm_data=False,
162163
MAXNUM_DATASETS=sys.maxsize,
164+
u_p_range=None,
163165
):
164166
filenames_xdmf = IO.pathsToFileType(data_path, ".xdmf", exclude="rectilinear")
165167
self.normalize_data = norm_data
@@ -202,9 +204,15 @@ def __init__(
202204
self.N = len(self.datasets)
203205

204206
# Calculate min and max pressure values from sampled datasets
205-
self._u_p_min, self._u_p_max = _calculate_u_pressure_minmax(
206-
self.datasets, self.tag_ufield
207-
)
207+
if u_p_range is not None:
208+
self._u_p_min, self._u_p_max = u_p_range
209+
print(
210+
f"Using specified u pressure range: [{self._u_p_min:.4f}, {self._u_p_max:.4f}]"
211+
)
212+
else:
213+
self._u_p_min, self._u_p_max = _calculate_u_pressure_minmax(
214+
self.datasets, self.tag_ufield
215+
)
208216

209217
# --- required abstract properties implemented ---
210218
@property
@@ -264,6 +272,7 @@ def __init__(
264272
data_prune=1,
265273
norm_data=False,
266274
MAXNUM_DATASETS=sys.maxsize,
275+
u_p_range=None,
267276
):
268277
filenamesH5 = IO.pathsToFileType(data_path, ".h5", exclude="rectilinear")
269278
self.data_prune = data_prune
@@ -295,10 +304,13 @@ def __init__(
295304
)
296305
self.tsteps = r[self.tags_field[0]].attrs["time_steps"]
297306
self.tsteps = jnp.array([t for t in self.tsteps if t <= tmax / t_norm])
298-
self.tsteps = self.tsteps * t_norm
307+
self.tsteps = (
308+
self.tsteps * t_norm
309+
) # corresponding to c = 1 for same spatial / temporal resolution
299310

300311
if self.normalize_data:
301312
self.mesh = self.normalize_spatial(self.mesh)
313+
# normalize relative to spatial dimension to keep ratio
302314
self.tsteps = self.normalize_temporal(self.tsteps)
303315

304316
self.tt = np.repeat(self.tsteps, self.mesh.shape[0])
@@ -308,16 +320,21 @@ def __init__(
308320
for i in range(0, min(MAXNUM_DATASETS, len(filenamesH5))):
309321
filename = filenamesH5[i]
310322
if Path(filename).exists():
311-
self.datasets.append(
312-
h5py.File(filename, "r")
313-
) # add file handles and keeps open
323+
# add file handles and keeps open
324+
self.datasets.append(h5py.File(filename, "r"))
314325
else:
315326
print(f"Could not be found (ignoring): {filename}")
316327

317328
# Calculate min and max pressure values from sampled datasets
318-
self._u_p_min, self._u_p_max = _calculate_u_pressure_minmax(
319-
self.datasets, self.tag_ufield
320-
)
329+
if u_p_range is not None:
330+
self._u_p_min, self._u_p_max = u_p_range
331+
print(
332+
f"Using specified u pressure range: [{self._u_p_min:.4f}, {self._u_p_max:.4f}]"
333+
)
334+
else:
335+
self._u_p_min, self._u_p_max = _calculate_u_pressure_minmax(
336+
self.datasets, self.tag_ufield
337+
)
321338

322339
# --- required abstract properties implemented ---
323340
@property
@@ -332,8 +349,11 @@ def P(self):
332349

333350
@property
334351
def xxyyzztt(self):
335-
xxyyzz = np.tile(self.mesh, (len(self.tsteps), 1))
336-
return np.hstack((xxyyzz, self.tt.reshape(-1, 1)))
352+
return np.hstack((self.xxyyzz, self.tt.reshape(-1, 1)))
353+
354+
@property
355+
def xxyyzz(self):
356+
return np.tile(self.mesh, (len(self.tsteps), 1))
337357

338358
def normalize_spatial(self, data):
339359
return _normalize_spatial(data, self.xmin, self.xmax)
@@ -412,6 +432,7 @@ def __init__(
412432
flatten_ic: bool = True,
413433
data_prune: int = 1,
414434
norm_data: bool = False,
435+
u_p_range=None,
415436
) -> None:
416437
self.data_prune = data_prune
417438
self._normalize_data = norm_data
@@ -478,9 +499,15 @@ def __init__(
478499
)
479500

480501
# Calculate min and max pressure values from generated datasets
481-
self._u_p_min, self._u_p_max = _calculate_u_pressure_minmax(
482-
self.datasets, self.tag_ufield, max_samples=self.N
483-
)
502+
if u_p_range is not None:
503+
self._u_p_min, self._u_p_max = u_p_range
504+
print(
505+
f"Using specified u pressure range: [{self._u_p_min:.4f}, {self._u_p_max:.4f}]"
506+
)
507+
else:
508+
self._u_p_min, self._u_p_max = _calculate_u_pressure_minmax(
509+
self.datasets, self.tag_ufield, max_samples=self.N
510+
)
484511

485512
self.tt = np.repeat(self.tsteps, self.mesh.shape[0])
486513

@@ -509,8 +536,11 @@ def normalize_data(self) -> bool:
509536
@property
510537
def xxyyzztt(self) -> np.ndarray[float]:
511538
"""Spatio-temporal coordinates stacked as [x, y, z, t]."""
512-
xxyyzz = np.tile(self.mesh, (len(self.tsteps), 1))
513-
return np.hstack((xxyyzz, self.tt.reshape(-1, 1)))
539+
return np.hstack((self.xxyyzz, self.tt.reshape(-1, 1)))
540+
541+
@property
542+
def xxyyzz(self) -> np.ndarray[float]:
543+
return np.tile(self.mesh, (len(self.tsteps), 1))
514544

515545
def normalize_spatial(self, data: np.ndarray[float]) -> np.ndarray[float]:
516546
return _normalize_spatial(data, self._xmin, self._xmax)
@@ -606,7 +636,7 @@ def __getitem__(self, idx):
606636
0:num_tsteps, :: self.data.data_prune
607637
].flatten()[indxs_coord]
608638
elif self.data.simulationDataType == SimulationDataType.XDMF:
609-
s = np.empty((self.P), dtype=jnp.float32)
639+
s = np.empty((self.P), dtype=np.float32)
610640
for j in range(num_tsteps):
611641
s[j * self.data.P_mesh : (j + 1) * self.P_mesh] = dataset[
612642
self.data.tags_field[j]
@@ -615,7 +645,9 @@ def __getitem__(self, idx):
615645
elif self.data.simulationDataType == SimulationDataType.SOURCE_ONLY:
616646
s = []
617647
else:
618-
raise Exception("Data format unknown: should be H5COMPACT or XDMF")
648+
raise Exception(
649+
"Data format unknown: should be H5COMPACT, XDMF or SOURCE_ONLY"
650+
)
619651

620652
# normalize
621653
x0 = (
@@ -624,19 +656,38 @@ def __getitem__(self, idx):
624656
else []
625657
)
626658

627-
inputs = jnp.asarray(u), jnp.asarray(y)
628-
return inputs, jnp.asarray(s), indxs_coord, x0
659+
inputs = np.asarray(u), np.asarray(y)
660+
return inputs, np.asarray(s), indxs_coord, x0
629661

630662

631663
def get_number_of_sources(data_path: str):
632664
return len(IO.pathsToFileType(data_path, ".h5", exclude="rectilinear"))
633665

634666

635667
def numpy_collate(batch):
668+
"""Collate function for JAX - converts batches to JAX arrays."""
636669
if isinstance(batch[0], np.ndarray):
637670
return jnp.stack(batch)
638671
elif isinstance(batch[0], (tuple, list)):
639672
transposed = zip(*batch)
640673
return [numpy_collate(samples) for samples in transposed]
641674
else:
642675
return np.array(batch)
676+
677+
678+
def pytorch_collate(batch):
679+
"""Collate function for PyTorch - converts batches to PyTorch tensors.
680+
681+
Use this collator with PyTorch DataLoader for educational notebooks:
682+
DataLoader(dataset, batch_size=2, collate_fn=pytorch_collate)
683+
"""
684+
685+
if isinstance(batch[0], np.ndarray):
686+
return torch.from_numpy(np.stack(batch)).float()
687+
elif isinstance(batch[0], jnp.ndarray):
688+
return torch.from_numpy(np.array(batch)).float()
689+
elif isinstance(batch[0], (tuple, list)):
690+
transposed = zip(*batch)
691+
return [pytorch_collate(samples) for samples in transposed]
692+
else:
693+
return torch.tensor(batch).float()

deeponet_acoustics/end2end/inference.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -133,10 +133,9 @@ def inference(
133133
model.plotLosses(settings.dirs.figs_dir)
134134

135135
tdim = len(metadata.tsteps)
136-
xxyyzztt = metadata.xxyyzztt
137-
y_in = y_feat_fn(xxyyzztt)
136+
y_in = y_feat_fn(metadata.xxyyzztt)
138137

139-
xxyyzz_phys = metadata.denormalize_spatial(xxyyzztt[:, 0:3])
138+
xxyyzz_phys = metadata.denormalize_spatial(metadata.xxyyzz)
140139
mesh_phys = metadata.denormalize_spatial(metadata.mesh)
141140
tsteps_phys = metadata.denormalize_temporal(metadata.tsteps / c_phys)
142141

@@ -196,9 +195,15 @@ def inference(
196195
r0_list_norm = metadata.normalize_spatial(r0_list[i_src])
197196

198197
(u_test_i, *_), s_test_i, _, x0 = dataset[i_src]
199-
200-
x0 = metadata.denormalize_spatial(x0)
201-
x0_srcs.append(x0)
198+
if len(x0) > 0:
199+
x0 = metadata.denormalize_spatial(x0)
200+
x0_srcs.append(x0)
201+
else:
202+
print(
203+
"Warning: test data does not have source position data - setting index as coordinate"
204+
)
205+
x0 = i_src
206+
x0_srcs.append([x0])
202207

203208
y_rcvs = np.repeat(np.array(r0_list_norm), len(metadata.tsteps), axis=0)
204209
tsteps_rcvs = np.tile(metadata.tsteps, len(r0_list_norm))

deeponet_acoustics/end2end/train.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ def train(settings_dict: dict[str, Any]):
6565
t_norm=c_phys,
6666
norm_data=settings.normalize_data,
6767
flatten_ic=flatten_ic,
68+
u_p_range=(-2.0, 2.0),
6869
)
6970
dataset = DatasetStreamer(
7071
metadata, training.batch_size_coord, y_feat_extract_fn=y_feat_fn
@@ -75,6 +76,7 @@ def train(settings_dict: dict[str, Any]):
7576
t_norm=c_phys,
7677
norm_data=settings.normalize_data,
7778
flatten_ic=flatten_ic,
79+
u_p_range=(-2.0, 2.0),
7880
)
7981
dataset_val = DatasetStreamer(
8082
metadata_val, training.batch_size_coord, y_feat_extract_fn=y_feat_fn

deeponet_acoustics/models/deeponet.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -219,13 +219,13 @@ def train(
219219
i += 1
220220

221221
if do_timings:
222-
timer.startTiming("backprop") if do_timings else None
222+
timer.startTiming("backprop")
223223
self.params, self.opt_state, _ = self.step(
224224
self.params, self.opt_state, data_batch
225225
)
226-
jax.block_until_ready(self.params) if do_timings else None
227-
jax.block_until_ready(self.opt_state) if do_timings else None
228-
timer.endTiming("backprop") if do_timings else None
226+
jax.block_until_ready(self.params)
227+
jax.block_until_ready(self.opt_state)
228+
timer.endTiming("backprop")
229229

230230
timer.writeTimings(
231231
{
@@ -235,8 +235,8 @@ def train(
235235
}
236236
)
237237
timer.resetTimings()
238-
timer.startTiming("total_iter") if do_timings else None
239-
timer.startTiming("dataloader") if do_timings else None
238+
timer.startTiming("total_iter")
239+
timer.startTiming("dataloader")
240240
else:
241241
self.params, self.opt_state, _ = self.step(
242242
self.params, self.opt_state, data_batch

deeponet_acoustics/scripts/convertH5/split_2D_by_source.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,19 @@ def split_2d_by_source_position(input_file: Path, output_path: Path):
8888
x0_srcs = None
8989

9090
print(f"Converting {num_sources} source positions from 2D to 3D format...")
91+
92+
# Compute denormalized values once for all sources
93+
c_phys = float(pressure_attrs["c_phys"])
94+
fmax_normalized = float(pressure_attrs["fmax"])
95+
fmax_phys = round(fmax_normalized * c_phys, 1)
96+
dt_normalized = float(pressure_attrs["dt"])
97+
dt_phys = dt_normalized / c_phys
98+
99+
# Update pressure_attrs with denormalized values for use in root metadata
100+
pressure_attrs_denormalized = pressure_attrs.copy()
101+
pressure_attrs_denormalized["fmax"] = fmax_phys
102+
pressure_attrs_denormalized["dt"] = dt_phys
103+
91104
src_pos_2d = None
92105

93106
for src_idx in range(num_sources):
@@ -120,10 +133,12 @@ def split_2d_by_source_position(input_file: Path, output_path: Path):
120133
"pressures", data=src_pressures
121134
)
122135

123-
# Add time_steps attribute to pressures (from t dataset)
124-
pressures_dataset.attrs["time_steps"] = t.astype(np.float32)
136+
# Denormalize time_steps by dividing by physical speed of sound
137+
time_steps_phys = (t / c_phys).astype(np.float32)
138+
pressures_dataset.attrs["time_steps"] = time_steps_phys
139+
125140
# Note: Adding all pressure attributes that weren't in original 3D data
126-
for attr_name, attr_value in pressure_attrs.items():
141+
for attr_name, attr_value in pressure_attrs_denormalized.items():
127142
if attr_name != "time_steps": # Avoid duplicate
128143
pressures_dataset.attrs[attr_name] = attr_value
129144

@@ -147,13 +162,15 @@ def split_2d_by_source_position(input_file: Path, output_path: Path):
147162
for attr_name, attr_value in upressure_attrs.items():
148163
upressures_dataset.attrs[attr_name] = attr_value
149164

150-
# Create metadata JSON file in source folder
151-
create_metadata_json(src_folder, pressure_attrs, src_pos_2d)
165+
# Create metadata JSON file in source folder using corrected attributes
166+
create_metadata_json(
167+
src_folder, dict(pressures_dataset.attrs), src_pos_2d
168+
)
152169

153170
print(f"Created {output_file}")
154171

155172
# Create metadata JSON file in root directory as well
156-
create_metadata_json(output_path, pressure_attrs, None)
173+
create_metadata_json(output_path, pressure_attrs_denormalized, None)
157174

158175
print(f"Conversion complete! Created {num_sources} files in {output_path}")
159176

deeponet_acoustics/json_setups/evaluations/loss_plot_ids.py renamed to json_setups/loss_plot_ids.py

File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.

0 commit comments

Comments
 (0)