
Commit 78c8f8a

Begin debugging grid sparsity OOM issues.
1 parent 8e98aaf · commit 78c8f8a

File tree

4 files changed: +241 additions, -55 deletions


src/mdio/converters/segy.py

Lines changed: 18 additions & 2 deletions
@@ -375,38 +375,47 @@ def segy_to_mdio( # noqa: PLR0913, PLR0915, PLR0912
     storage_options_input = storage_options_input or {}
     storage_options_output = storage_options_output or {}
 
+    print("Opening SEG-Y...")
+
     # Open SEG-Y with MDIO's SegySpec. Endianness will be inferred.
     mdio_spec = mdio_segy_spec()
+    print("MDIO spec created")
     segy_settings = SegySettings(storage_options=storage_options_input)
     segy = SegyFile(url=segy_path, spec=mdio_spec, settings=segy_settings)
+    print("SEG-Y file opened")
 
     text_header = segy.text_header
     binary_header = segy.binary_header
     num_traces = segy.num_traces
-
+    print("num_traces", num_traces)
     # Index the dataset using a spec that interprets the user provided index headers.
     index_fields = []
     for name, byte, format_ in zip(index_names, index_bytes, index_types, strict=True):
         index_fields.append(HeaderField(name=name, byte=byte, format=format_))
     mdio_spec_grid = mdio_spec.customize(trace_header_fields=index_fields)
     segy_grid = SegyFile(url=segy_path, spec=mdio_spec_grid, settings=segy_settings)
-
+    print("SEGY grid created")
     dimensions, chunksize, index_headers = get_grid_plan(
         segy_file=segy_grid,
         return_headers=True,
         chunksize=chunksize,
         grid_overrides=grid_overrides,
     )
+    print("grid plan created")
     grid = Grid(dims=dimensions)
+    print("grid created")
     grid_density_qc(grid, num_traces)
+    print("grid density qc done")
     grid.build_map(index_headers)
 
     # Check grid validity by ensuring every trace's header-index is within dimension bounds
     valid_mask = np.ones(grid.num_traces, dtype=bool)
+    print("valid_mask shape", valid_mask.shape)
     for d_idx in range(len(grid.header_index_arrays)):
         coords = grid.header_index_arrays[d_idx]
         valid_mask &= coords < grid.shape[d_idx]
     valid_count = int(np.count_nonzero(valid_mask))
+    print("valid_count", valid_count)
     if valid_count != num_traces:
         for dim_name in grid.dim_names:
             dim_min = grid.get_min(dim_name)

@@ -417,6 +426,8 @@ def segy_to_mdio( # noqa: PLR0913, PLR0915, PLR0912
 
     import gc
 
+    # raise Exception("Stop here")
+
     del valid_mask
     gc.collect()
 

@@ -453,6 +464,8 @@ def segy_to_mdio( # noqa: PLR0913, PLR0915, PLR0912
     )
     config = MDIOCreateConfig(path=mdio_path_or_buffer, grid=grid, variables=[var_conf])
 
+    print("Creating empty...")
+
     root_group = create_empty(
         config,
         overwrite=overwrite,

@@ -464,7 +477,10 @@ def segy_to_mdio( # noqa: PLR0913, PLR0915, PLR0912
     data_array = data_group[f"chunked_{suffix}"]
     header_array = meta_group[f"chunked_{suffix}_trace_headers"]
 
+    print("Creating live mask...")
+
     live_mask_array = meta_group["live_mask"]
+    print(live_mask_array.shape)
     # 'live_mask_array' has the same first N-1 dims as 'grid.shape[:-1]'
     # Build a ChunkIterator over the live_mask (no sample axis)
     from mdio.core.indexing import ChunkIterator
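
Note on the grid validity check instrumented above: it only keeps traces whose header-derived index falls inside the grid bounds on every indexed dimension, building the mask one dimension at a time so peak memory stays at a single boolean array of num_traces elements. A minimal NumPy sketch of that masking step — the header_index_arrays and shape attributes come from the diff, but the values below are made up for illustration:

import numpy as np

# Hypothetical stand-ins for grid.header_index_arrays and grid.shape (values made up).
header_index_arrays = [np.array([0, 1, 5, 2]), np.array([3, 9, 1, 4])]
grid_shape = (4, 8, 1500)  # last axis is the sample axis and is not indexed here

num_traces = len(header_index_arrays[0])
valid_mask = np.ones(num_traces, dtype=bool)
for d_idx, coords in enumerate(header_index_arrays):
    valid_mask &= coords < grid_shape[d_idx]  # keep traces inside this dimension's bounds

valid_count = int(np.count_nonzero(valid_mask))
print(valid_count)  # 2 -> only traces 0 and 3 fall inside the 4 x 8 grid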

src/mdio/segy/geometry.py

Lines changed: 77 additions & 48 deletions
@@ -202,18 +202,22 @@ def create_counter(
     total_depth: int,
     unique_headers: dict[str, NDArray],
     header_names: list[str],
-) -> dict[str, dict]:
-    """Helper function to create dictionary tree for counting trace key for auto index."""
-    if depth == total_depth:
-        return 0
-
-    counter = {}
-
-    header_key = header_names[depth]
-    for header in unique_headers[header_key]:
-        counter[header] = create_counter(depth + 1, total_depth, unique_headers, header_names)
-
-    return counter
+) -> dict[tuple, int]:
+    """Helper function to create flat counter dictionary for counting trace keys for auto index.
+
+    This is a memory-efficient version that returns an empty dict since we now process
+    traces directly in create_trace_index without pre-allocating the counter structure.
+
+    Args:
+        depth: Current recursion depth (unused in new implementation)
+        total_depth: Total depth of headers (unused in new implementation)
+        unique_headers: Dictionary of unique header values (unused in new implementation)
+        header_names: List of header names (unused in new implementation)
+
+    Returns:
+        Empty dictionary - actual counting happens in create_trace_index
+    """
+    return {}
 
 
 def create_trace_index(

@@ -223,63 +227,88 @@ def create_trace_index(
     header_names: list[str],
     dtype: DTypeLike = np.int16,
 ) -> NDArray | None:
-    """Update dictionary counter tree for counting trace key for auto index."""
+    """Memory-efficient trace index creation that processes traces in a single pass.
+
+    Args:
+        depth: Number of header dimensions to process
+        counter: Counter dictionary (unused in new implementation)
+        index_headers: numpy array with index headers
+        header_names: List of header field names
+        dtype: numpy type for value of created trace header
+
+    Returns:
+        HeaderArray with added 'trace' field containing trace indices, or None if depth is 0
+    """
     if depth == 0:
         # If there's no hierarchical depth, no tracing needed.
         return None
 
-    # Add index header
+    # Add trace field
     trace_no_field = np.zeros(index_headers.shape, dtype=dtype)
     index_headers = rfn.append_fields(index_headers, "trace", trace_no_field, usemask=False)
-
-    # Extract the relevant columns upfront
-    headers = [index_headers[name] for name in header_names[:depth]]
-    for idx, idx_values in enumerate(zip(*headers, strict=True)):
-        if depth == 1:
-            counter[idx_values[0]] += 1
-            index_headers["trace"][idx] = counter[idx_values[0]]
-        else:
-            sub_counter = counter
-            for idx_value in idx_values[:-1]:
-                sub_counter = sub_counter[idx_value]
-            sub_counter[idx_values[-1]] += 1
-            index_headers["trace"][idx] = sub_counter[idx_values[-1]]
-
+
+    # Use a flat dictionary with tuple keys instead of nested dictionaries
+    # This avoids pre-allocating memory for all possible combinations
+    flat_counter = {}
+
+    # Only use the first 'depth' header names
+    relevant_header_names = header_names[:depth]
+
+    # Process each trace in a single pass
+    for idx in range(len(index_headers)):
+        # Create tuple key from header values for this trace
+        key = tuple(index_headers[name][idx] for name in relevant_header_names)
+
+        # Increment counter for this combination and assign trace number
+        flat_counter[key] = flat_counter.get(key, 0) + 1
+        index_headers["trace"][idx] = flat_counter[key]
+
     return index_headers
 
 
 def analyze_non_indexed_headers(index_headers: HeaderArray, dtype: DTypeLike = np.int16) -> NDArray:
     """Check input headers for SEG-Y input to help determine geometry.
 
-    This function reads in trace_qc_count headers and finds the unique cable values. Then, it
-    checks to make sure channel numbers for different cables do not overlap.
+    This function reads in trace_qc_count headers and creates trace indices efficiently.
+    Uses a memory-efficient approach that doesn't pre-allocate large nested dictionaries.
 
     Args:
         index_headers: numpy array with index headers
         dtype: numpy type for value of created trace header.
 
     Returns:
-        Dict container header name as key and numpy array of values as value
+        HeaderArray with added 'trace' field containing trace indices
     """
-    # Find unique cable ids
     t_start = time.perf_counter()
-    unique_headers = {}
-    total_depth = 0
-    header_names = []
-    for header_key in index_headers.dtype.names:
-        if header_key != "trace":
-            unique_headers[header_key] = np.sort(np.unique(index_headers[header_key]))
-            header_names.append(header_key)
-            total_depth += 1
-
-    counter = create_counter(0, total_depth, unique_headers, header_names)
-
-    index_headers = create_trace_index(
-        total_depth, counter, index_headers, header_names, dtype=dtype
-    )
-
+
+    # Get header names excluding 'trace' if it already exists
+    header_names = [name for name in index_headers.dtype.names if name != "trace"]
+
+    if not header_names:
+        # No headers to process, just add trace numbers sequentially
+        trace_no_field = np.arange(1, len(index_headers) + 1, dtype=dtype)
+        index_headers = rfn.append_fields(index_headers, "trace", trace_no_field, usemask=False)
+        return index_headers
+
+    # Create trace field
+    trace_no_field = np.zeros(index_headers.shape, dtype=dtype)
+    index_headers = rfn.append_fields(index_headers, "trace", trace_no_field, usemask=False)
+
+    # Use a flat dictionary with tuple keys instead of nested dictionaries
+    # This avoids pre-allocating memory for all possible combinations
+    counter = {}
+
+    # Process each trace in a single pass
+    for idx in range(len(index_headers)):
+        # Create tuple key from header values for this trace
+        key = tuple(index_headers[name][idx] for name in header_names)
+
+        # Increment counter for this combination and assign trace number
+        counter[key] = counter.get(key, 0) + 1
+        index_headers["trace"][idx] = counter[key]
+
     t_stop = time.perf_counter()
-    logger.debug("Time spent generating trace index: %.4f s", t_start - t_stop)
+    logger.debug("Time spent generating trace index: %.4f s", t_stop - t_start)
     return index_headers
 
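Note on the rewrite above: the old code pre-built a nested counter tree with one entry for every possible combination of unique header values, while the new code keeps a single flat dict keyed by the tuple of header values seen on each trace, so memory grows only with the combinations that actually occur. A minimal, self-contained sketch of that flat-counter idea — the field names "cable" and "channel" and the toy data are illustrative, not taken from this file:

import numpy as np
from numpy.lib import recfunctions as rfn

# Toy structured header array; the field names are illustrative only.
headers = np.array(
    [(101, 1), (101, 2), (102, 1), (101, 1)],
    dtype=[("cable", "i4"), ("channel", "i4")],
)
headers = rfn.append_fields(headers, "trace", np.zeros(len(headers), dtype=np.int16), usemask=False)

counter: dict[tuple, int] = {}
for idx in range(len(headers)):
    key = tuple(headers[name][idx] for name in ("cable", "channel"))
    counter[key] = counter.get(key, 0) + 1  # count only combinations that actually occur
    headers["trace"][idx] = counter[key]

print(headers["trace"])  # [1 1 1 2] -> the repeated (101, 1) pair gets trace number 2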

src/mdio/segy/parsers.py

Lines changed: 102 additions & 1 deletion
@@ -39,7 +39,9 @@ def parse_index_headers(
         current block. Array is of type byte_type except IBM32 which is mapped to FLOAT32.
     """
     trace_count = segy_file.num_traces
+    print(f"trace_count: {trace_count}")
     n_blocks = int(ceil(trace_count / block_size))
+    print(f"n_blocks: {n_blocks}")
 
     trace_ranges = []
     for idx in range(n_blocks):

@@ -69,5 +71,104 @@ def parse_index_headers(
     # This executes the lazy work.
     headers: list[HeaderArray] = list(lazy_work)
 
+    print("Concatenating headers...", flush=True)
+    # raise Exception("Stop here")
+    # ret = memory_efficient_concatenate(headers)
+    ret = np.concatenate(headers)
+    print("Finished!", flush=True)
     # Merge blocks before return
-    return np.concatenate(headers)
+    return ret
+
+
+def memory_efficient_concatenate(headers: list[HeaderArray]) -> HeaderArray:
+    """Memory-efficient concatenation for many small header arrays.
+
+    Pre-allocates the target array and copies data in place to avoid
+    the memory fragmentation and intermediate allocations that occur
+    with np.concatenate on many small arrays.
+
+    Args:
+        headers: List of HeaderArray objects to concatenate
+
+    Returns:
+        Single concatenated HeaderArray
+    """
+
+    # Heartbeat 1: Function entry
+    with open("heartbeat_1_entry.txt", "w") as f:
+        f.write("Entered memory_efficient_concatenate\n")
+        f.flush()
+
+    if not headers:
+        raise ValueError("Cannot concatenate empty list of arrays")
+
+    # Heartbeat 2: Before size calculation
+    with open("heartbeat_2_calculating_size.txt", "w") as f:
+        f.write(f"Starting size calculation for {len(headers)} arrays\n")
+        f.flush()
+
+    # Calculate total size and get array metadata
+    total_length = sum(len(arr) for arr in headers)
+    first_array = headers[0]
+    target_dtype = first_array.dtype
+
+    # Heartbeat 3: Before allocation
+    estimated_size_mb = total_length * target_dtype.itemsize / 1024**2
+    with open("heartbeat_3_before_allocation.txt", "w") as f:
+        f.write(f"About to allocate {total_length:,} elements, estimated {estimated_size_mb:.1f} MB\n")
+        f.flush()
+
+    print(f"Pre-allocating result array: {total_length:,} elements, "
+          f"dtype={target_dtype}, estimated size={estimated_size_mb:.1f} MB")
+
+    # Pre-allocate the final array - this is the key optimization
+    result = np.empty(total_length, dtype=target_dtype)
+
+    # Heartbeat 4: After allocation
+    actual_size_mb = result.nbytes / 1024**2
+    with open("heartbeat_4_allocated.txt", "w") as f:
+        f.write(f"Successfully allocated array, actual size {actual_size_mb:.1f} MB\n")
+        f.flush()
+
+    # Copy arrays sequentially into pre-allocated space
+    current_pos = 0
+    batch_size = 100  # Process in batches to provide progress feedback
+
+    # Heartbeat 5: Before copying loop
+    with open("heartbeat_5_start_copying.txt", "w") as f:
+        f.write("Starting copy loop\n")
+        f.flush()
+
+    for i in range(0, len(headers), batch_size):
+        batch_end = min(i + batch_size, len(headers))
+
+        # Process this batch
+        for j in range(i, batch_end):
+            arr = headers[j]
+            if arr is None:  # Skip if already processed
+                continue
+
+            end_pos = current_pos + len(arr)
+
+            # Direct copy into pre-allocated space - no intermediate allocations
+            result[current_pos:end_pos] = arr
+            current_pos = end_pos
+
+            # Help garbage collector by clearing reference
+            headers[j] = None
+
+        # Progress update and heartbeat for major milestones
+        if batch_end % (5 * batch_size) == 0 or batch_end == len(headers):
+            progress_pct = (batch_end / len(headers)) * 100
+            with open(f"heartbeat_6_progress_{int(progress_pct)}.txt", "w") as f:
+                f.write(f"Progress: {progress_pct:.1f}% ({batch_end:,}/{len(headers):,})\n")
+                f.flush()
+            print(f"Concatenation progress: {progress_pct:.1f}% ({batch_end:,}/{len(headers):,} arrays)")
+
+    # Heartbeat 7: Completion
+    with open("heartbeat_7_complete.txt", "w") as f:
+        f.write(f"Concatenation complete. Final size: {result.nbytes / 1024**2:.1f} MB\n")
+        f.flush()
+
+    print(f"Concatenation complete. Final array size: {result.nbytes / 1024**2:.1f} MB")
+    return result
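
Note on memory_efficient_concatenate above (currently commented out in parse_index_headers in favour of np.concatenate): it pre-allocates one destination array, copies each block into place, and drops references as it goes, so there is only a single large allocation; the heartbeat files exist purely for debugging. A stripped-down sketch of the same pre-allocate-and-copy pattern, using a hypothetical preallocated_concat helper name and toy data:

import numpy as np

def preallocated_concat(blocks: list) -> np.ndarray:
    """Concatenate 1-D blocks sharing a dtype into one pre-allocated array."""
    if not blocks:
        raise ValueError("Cannot concatenate an empty list of arrays")
    total = sum(len(b) for b in blocks)
    out = np.empty(total, dtype=blocks[0].dtype)  # single allocation up front
    pos = 0
    for i, block in enumerate(blocks):
        out[pos:pos + len(block)] = block  # copy in place, no intermediate arrays
        pos += len(block)
        blocks[i] = None  # drop the reference so the block can be freed early
    return out

# Example: three small blocks merged into one contiguous array.
parts = [np.arange(3), np.arange(3, 5), np.arange(5, 9)]
print(preallocated_concat(parts))  # [0 1 2 3 4 5 6 7 8]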
