Commit 5e39af6

Reset to main and add profiling
1 parent 05faeb2 commit 5e39af6

2 files changed: +58 -47 lines changed

src/mdio/core/grid.py

Lines changed: 13 additions & 18 deletions

@@ -134,24 +134,19 @@ def build_map(self, index_headers: HeaderArray) -> None:
         # Process headers in batches
         for start in range(0, total_live_traces, batch_size):
             end = min(start + batch_size, total_live_traces)
-
-            # 1) build your per-dimension index arrays
-            live_dim_indices = [
-                np.searchsorted(dim, index_headers[dim.name][start:end])
-                .astype(np.uint32)
-                for dim in self.dims[:-1]
-            ]
-
-            # 2) flatten to 1D indices
-            flat_idx = np.ravel_multi_index(tuple(live_dim_indices), dims=self.map.shape)
-
-            # 3) write into flattened views
-            flat_map = self.map.reshape(-1)
-            flat_mask = self.live_mask.reshape(-1)
-            trace_indices = np.arange(start, end, dtype=flat_map.dtype)
-
-            flat_map[flat_idx] = trace_indices
-            flat_mask[flat_idx] = True
+            live_dim_indices = []
+
+            # Compute indices for the batch
+            for dim in self.dims[:-1]:
+                dim_hdr = index_headers[dim.name][start:end]
+                indices = np.searchsorted(dim, dim_hdr).astype(np.uint32)
+                live_dim_indices.append(indices)
+            live_dim_indices = tuple(live_dim_indices)
+
+            # Assign trace indices
+            trace_indices = np.arange(start, end, dtype=np.uint64)
+            self.map.vindex[live_dim_indices] = trace_indices
+            self.live_mask.vindex[live_dim_indices] = True
 
 
 class GridSerializer(Serializer):
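
Note: the replacement code assigns through .vindex, zarr's coordinate (pointwise) indexing, instead of flattening the map and mask and scattering into them via np.ravel_multi_index. A minimal sketch of that indexing pattern follows; it assumes grid.map and grid.live_mask are zarr arrays (which the .vindex call implies), and the names trace_map, live_mask, rows, and cols are illustrative stand-ins, not MDIO identifiers.

import numpy as np
import zarr

# Small stand-ins for grid.map and grid.live_mask (shapes are arbitrary).
trace_map = zarr.zeros((4, 5), dtype="uint64")
live_mask = zarr.zeros((4, 5), dtype=bool)

# One index array per dimension, as built with np.searchsorted in build_map:
# element i gives trace i's coordinate along that dimension.
rows = np.array([0, 1, 3], dtype=np.uint32)
cols = np.array([2, 0, 4], dtype=np.uint32)

# .vindex performs pointwise assignment: (rows[i], cols[i]) receives
# trace_indices[i], with no reshape(-1) / ravel_multi_index round trip.
trace_indices = np.arange(3, dtype=np.uint64)
trace_map.vindex[(rows, cols)] = trace_indices
live_mask.vindex[(rows, cols)] = True

print(trace_map[:])   # trace numbers scattered at the selected coordinates
print(live_mask[:])   # True at the same coordinates, False elsewhere

Each .vindex assignment is a read-modify-write on every chunk the selection touches, so grouping many points into one call per batch, as the loop above does, is generally much cheaper than assigning point by point.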

src/mdio/segy/_workers.py

Lines changed: 45 additions & 29 deletions

@@ -16,6 +16,10 @@
 
 from mdio.core import Grid
 
+import os
+import cProfile
+import pstats
+
 
 def header_scan_worker(segy_file: SegyFile, trace_range: tuple[int, int]) -> HeaderArray:
     """Header scan worker.

@@ -77,45 +81,57 @@ def trace_worker(
     Returns:
         Partial statistics for chunk, or None
     """
-    # Special case where there are no traces inside chunk.
-    live_subset = grid.live_mask[chunk_indices[:-1]]
 
-    if np.count_nonzero(live_subset) == 0:
-        return None
+    profiler = cProfile.Profile()
+    profiler.enable()
+    try:
+
+        # Special case where there are no traces inside chunk.
+        live_subset = grid.live_mask[chunk_indices[:-1]]
+
+        if np.count_nonzero(live_subset) == 0:
+            return None
 
-    # Let's get trace numbers from grid map using the chunk indices.
-    seq_trace_indices = grid.map[chunk_indices[:-1]]
+        # Let's get trace numbers from grid map using the chunk indices.
+        seq_trace_indices = grid.map[chunk_indices[:-1]]
 
-    tmp_data = np.zeros(seq_trace_indices.shape + (grid.shape[-1],), dtype=data_array.dtype)
-    tmp_metadata = np.zeros(seq_trace_indices.shape, dtype=metadata_array.dtype)
+        tmp_data = np.zeros(seq_trace_indices.shape + (grid.shape[-1],), dtype=data_array.dtype)
+        tmp_metadata = np.zeros(seq_trace_indices.shape, dtype=metadata_array.dtype)
 
-    del grid  # To save some memory
+        del grid  # To save some memory
 
-    # Read headers and traces for block
-    valid_indices = seq_trace_indices[live_subset]
+        # Read headers and traces for block
+        valid_indices = seq_trace_indices[live_subset]
 
-    traces = segy_file.trace[valid_indices.tolist()]
-    headers, samples = traces["header"], traces["data"]
+        traces = segy_file.trace[valid_indices.tolist()]
+        headers, samples = traces["header"], traces["data"]
 
-    tmp_metadata[live_subset] = headers.view(tmp_metadata.dtype)
-    tmp_data[live_subset] = samples
+        tmp_metadata[live_subset] = headers.view(tmp_metadata.dtype)
+        tmp_data[live_subset] = samples
 
-    # Flush metadata to zarr
-    metadata_array.set_basic_selection(selection=chunk_indices[:-1], value=tmp_metadata)
+        # Flush metadata to zarr
+        metadata_array.set_basic_selection(selection=chunk_indices[:-1], value=tmp_metadata)
 
-    nonzero_mask = samples != 0
-    nonzero_count = nonzero_mask.sum(dtype="uint32")
+        nonzero_mask = samples != 0
+        nonzero_count = nonzero_mask.sum(dtype="uint32")
 
-    if nonzero_count == 0:
-        return None
+        if nonzero_count == 0:
+            return None
 
-    data_array.set_basic_selection(selection=chunk_indices, value=tmp_data)
+        data_array.set_basic_selection(selection=chunk_indices, value=tmp_data)
 
-    # Calculate statistics
-    tmp_data = samples[nonzero_mask]
-    chunk_sum = tmp_data.sum(dtype="float64")
-    chunk_sum_squares = np.square(tmp_data, dtype="float64").sum()
-    min_val = tmp_data.min()
-    max_val = tmp_data.max()
+        # Calculate statistics
+        tmp_data = samples[nonzero_mask]
+        chunk_sum = tmp_data.sum(dtype="float64")
+        chunk_sum_squares = np.square(tmp_data, dtype="float64").sum()
+        min_val = tmp_data.min()
+        max_val = tmp_data.max()
 
-    return nonzero_count, chunk_sum, chunk_sum_squares, min_val, max_val
+        return nonzero_count, chunk_sum, chunk_sum_squares, min_val, max_val
+    finally:
+        profiler.disable()
+        pid = os.getpid()
+        profile_path = f"/tmp/trace_worker_profile_{pid}.prof"
+        with open(profile_path, "w") as f:
+            ps = pstats.Stats(profiler, stream=f)
+            ps.strip_dirs().sort_stats("cumulative").print_stats()
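
Note: the profiling added to trace_worker follows the usual cProfile pattern: enable a profiler per call, run the body under try/finally, and write human-readable pstats output to a per-PID file so concurrent worker processes do not overwrite each other's reports. A self-contained sketch of the same pattern follows; the decorator name, output path, and example workload are hypothetical, not part of MDIO.

import cProfile
import functools
import os
import pstats


def profile_to_tmp(func):
    """Profile each call and append cumulative-time stats to a per-PID text file."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        profiler = cProfile.Profile()
        profiler.enable()
        try:
            return func(*args, **kwargs)
        finally:
            profiler.disable()
            # Keyed by PID so parallel workers write to separate files.
            path = f"/tmp/{func.__name__}_profile_{os.getpid()}.txt"
            with open(path, "a") as stream:
                stats = pstats.Stats(profiler, stream=stream)
                stats.strip_dirs().sort_stats("cumulative").print_stats(20)

    return wrapper


@profile_to_tmp
def example_worker(n: int) -> int:
    # Stand-in for trace_worker: any CPU-bound body will show up in the stats.
    return sum(i * i for i in range(n))


if __name__ == "__main__":
    example_worker(1_000_000)

If the per-PID files are meant to be reloaded later with pstats.Stats(path) or a viewer such as snakeviz, profiler.dump_stats(path) (binary format) is the variant to use; the text printed by print_stats is for direct reading only.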
