Add memory-efficient tiled inference with spline blending and D4 TTA #603

Open
jayakumarpujar wants to merge 7 commits into opengeos:main from jayakumarpujar:feature/inference-smoothing-tta

@jayakumarpujar
Collaborator

Summary

Implements a generic, memory-efficient tiled inference pipeline for GeoTIFF rasters as a new geoai/inference.py module, addressing issue #87.

  • Windowed rasterio I/O - reads tiles directly via rasterio windows, avoiding full-image input memory allocation
  • Four blending strategies via BlendMode enum: none, linear, cosine, and spline (new, from issue author's approach using scipy.signal.windows.triang)
  • D4 test-time augmentation - optional 8-fold augmentation (identity, 3 rotations, h-flip, v-flip, 2 diagonal flips) for improved predictions at 8x compute cost
  • Configurable pre/post-processing hooks - preprocess_fn and postprocess_fn callables for model-specific normalization
  • GeoTIFF metadata preservation - CRS, transform, and nodata are carried through to output
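The D4 augmentation listed above can be illustrated with a minimal NumPy sketch. The PR states that it uses torch.rot90/torch.flip; NumPy is swapped in here only to keep the example dependency-light, and the names d4_apply, d4_undo, and d4_tta are hypothetical stand-ins, not the PR's d4_forward/d4_inverse signatures.

```python
import numpy as np

def d4_apply(x: np.ndarray, k: int) -> np.ndarray:
    """Apply the k-th D4 transform (k in 0..7) to the last two axes of x."""
    if k >= 4:
        x = np.flip(x, axis=-1)  # k = 4..7: horizontal flip before rotating
    return np.rot90(x, k % 4, axes=(-2, -1))

def d4_undo(y: np.ndarray, k: int) -> np.ndarray:
    """Invert d4_apply: undo the rotation first, then the flip."""
    y = np.rot90(y, -(k % 4), axes=(-2, -1))
    if k >= 4:
        y = np.flip(y, axis=-1)
    return y

def d4_tta(predict, x: np.ndarray) -> np.ndarray:
    """Average predict() over all 8 transforms, mapping each output back."""
    outputs = [d4_undo(predict(d4_apply(x, k)), k) for k in range(8)]
    return np.mean(outputs, axis=0)
```

Each of the 8 transforms costs one extra forward pass, which is where the 8x compute cost comes from. Averaging after the inverse transform assumes the model output is spatially aligned with its input, as it is for segmentation or per-pixel regression.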

Key fixes vs prior attempt (PR #502)

  • Dedicated geoai/inference.py module instead of adding to the already-large utils.py
  • Proper boundary tile handling (crop predictions to actual tile dimensions, preventing shape mismatch)
  • Input validation (overlap >= tile_size raises ValueError)
  • num_classes parameter actually used to size the output accumulator
  • No redundant device transfers
  • Pure torch.rot90/torch.flip for D4 transforms (no torchvision dependency)
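The boundary-tile handling and overlap validation described above can be sketched as a small window generator. This is an illustration under stated assumptions, not the PR's code: tile_windows is a hypothetical helper, and the stride of tile_size - overlap is assumed from the description.

```python
def tile_windows(height: int, width: int, tile_size: int, overlap: int):
    """Yield (row, col, h, w) windows that cover a height x width raster.

    Tiles step by tile_size - overlap; tiles on the bottom/right edge are
    clipped, so h or w may be smaller than tile_size.
    """
    if not 0 <= overlap < tile_size:
        raise ValueError("overlap must be >= 0 and smaller than tile_size")
    stride = tile_size - overlap
    # Stop once a tile already reaches the raster edge (no redundant tiles).
    for row in range(0, max(height - overlap, 1), stride):
        for col in range(0, max(width - overlap, 1), stride):
            h = min(tile_size, height - row)
            w = min(tile_size, width - col)
            yield row, col, h, w
```

If the model needs a fixed input size, an edge tile would be padded up to tile_size before inference and the prediction cropped back to [:h, :w] before it is accumulated, which is the shape-mismatch fix the list describes.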

Files changed

| File | Change |
| --- | --- |
| geoai/inference.py | New module: BlendMode, create_weight_mask, d4_forward, d4_inverse, d4_tta_forward, predict_geotiff |
| geoai/__init__.py | Lazy import registration for all public symbols |
| tests/test_inference.py | 35 tests: imports, signatures, weight mask behavior, D4 roundtrips, validation |
| docs/inference.md | mkdocstrings API doc page |
| mkdocs.yml | Nav entry for inference module |

Usage

from geoai.inference import predict_geotiff

predict_geotiff(
    model=my_model,
    input_raster="input.tif",
    output_raster="output.tif",
    tile_size=256,
    overlap=64,
    blend_mode="spline",
    tta=False,
)

Test plan

  • All 35 unit tests pass (pytest tests/test_inference.py -v)
  • Lazy imports verified (from geoai import predict_geotiff, BlendMode)
  • Weight mask symmetry, shape, dtype, and value range validated for all 4 modes
  • D4 roundtrip verified (forward + inverse = identity for all 8 transforms)
  • Input validation tested (FileNotFoundError, ValueError for invalid overlap)

Closes #87

References

Implement a generic inference module (geoai/inference.py) for running
any PyTorch segmentation or regression model on large GeoTIFF rasters
with smooth tile blending and optional test-time augmentation.

Key features:
- Windowed rasterio I/O to avoid loading full input images into memory
- Four blending strategies (none, linear, cosine, spline) via BlendMode enum
- D4 dihedral group test-time augmentation (8-fold) for improved predictions
- Configurable pre/post-processing hooks for model-specific normalization
- Preserves GeoTIFF metadata (CRS, transform, nodata)

Closes opengeos#87
- Fix critical bug: weight normalization crashed when num_classes > 1
  due to boolean index shape mismatch. Replace with np.where broadcast.
- Add model output shape validation on first batch to catch mismatched
  num_classes early with a clear error message.
- Handle models returning (B, H, W) by auto-expanding to (B, 1, H, W).
- Avoid reopening input raster; capture profile inside first with block.
- Add tests for multi-class normalization and zero-weight nodata.
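The np.where-based normalization mentioned in the commit message can be sketched as follows. The array names, shapes, and the nodata fill value are assumptions taken from the discussion (a (C, H, W) accumulator and an (H, W) weight map); normalize_accumulator is a hypothetical helper, not the PR's function.

```python
import numpy as np

def normalize_accumulator(output_sum, weight_sum, nodata=-9999.0):
    """Divide a (C, H, W) weighted sum by an (H, W) weight map.

    Pixels that no tile covered (weight 0) are filled with nodata.
    """
    valid = weight_sum > 0                            # (H, W) boolean mask
    safe_weights = np.where(valid, weight_sum, 1.0)   # avoid division by zero
    averaged = output_sum / safe_weights              # (H, W) broadcasts over C
    return np.where(valid, averaged, nodata)          # mask broadcasts to (C, H, W)
```

Indexing a (C, H, W) array directly with an (H, W) boolean mask fails whenever C differs from H, which is consistent with the crash described for num_classes > 1. np.where sidesteps this because it aligns shapes from the trailing axes.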
jayakumarpujar force-pushed the feature/inference-smoothing-tta branch from 8ac781a to 5852fd5 on March 6, 2026 06:09
jayakumarpujar requested a review from giswqs on March 6, 2026 06:19
jayakumarpujar self-assigned this on Mar 6, 2026
@giswqs
Member

giswqs commented Mar 6, 2026

@jayakumarpujar Thank you for adding this new feature. Can you add a notebook example to demonstrate it?

jayakumarpujar and others added 4 commits March 6, 2026 11:21
- Spline blend now tapers only in overlap zones (flat plateau in center),
  matching linear/cosine behavior
- Replace random-weight demo models with fixed Sobel/directional kernels
  that produce spatially varying output for visible blending comparison
- Fix naip_preprocess to handle both uint8 and uint16 data
  (was dividing uint8 values by 10000, producing near-zero input)
- Replace demo models with deterministic designs using global average
  pooling (AdaptiveAvgPool2d) so predictions are tile-dependent and
  seam artifacts are clearly visible when blending is disabled
- Use band-ratio features (NDVI, greenness, brightness) for spatially
  varying output across the NAIP landscape
- MulticlassModel uses fully deterministic fixed weights for
  reproducible multi-class maps
@jayakumarpujar
Collaborator Author

[Image: Screenshot from 2026-03-06 11-50-40]

@giswqs I have updated with a notebook example. Please let me know if anything needs to be changed.

Contributor

Copilot AI left a comment

Pull request overview

Adds a new geoai.inference module that provides a generic, windowed GeoTIFF inference pipeline with overlap blending and optional D4 test-time augmentation, plus docs/tests and top-level lazy exports.

Changes:

  • Introduce geoai/inference.py with predict_geotiff, BlendMode, weight-mask generation, and D4 TTA helpers.
  • Register new public API via lazy imports in geoai/__init__.py and add dedicated unit tests.
  • Add mkdocs API page + example notebook and wire them into mkdocs.yml.

Reviewed changes

Copilot reviewed 6 out of 6 changed files in this pull request and generated 6 comments.

| File | Description |
| --- | --- |
| geoai/inference.py | New tiled inference implementation (blending + D4 TTA) and public API symbols |
| geoai/__init__.py | Lazy-export registration for inference symbols and submodule |
| tests/test_inference.py | Unit tests for imports/exports, signatures, weight masks, D4 transforms, validation |
| docs/inference.md | mkdocstrings entry for the new inference module |
| docs/examples/smooth_inference.ipynb | Example notebook demonstrating blending strategies + TTA usage |
| mkdocs.yml | Navigation entries for the new notebook and inference docs page |

Comment on lines +71 to +79
window = np.ones(window_size, dtype=np.float64)
# Left taper: 0 .. overlap -> ramp from 0 to 1
ramp = np.linspace(0.0, 1.0, overlap, endpoint=False)
# Apply half-cosine shaping then power for smoothness
ramp = (0.5 * (1.0 - np.cos(np.pi * ramp))) ** power
window[:overlap] = ramp
# Right taper: mirror of left
window[-overlap:] = ramp[::-1]
return window
Copilot AI Mar 8, 2026

_spline_window_1d() uses np.linspace(..., endpoint=False), so the last pixel in the taper never reaches weight 1.0, and the very next pixel jumps to 1.0 due to the center being initialized to ones. This introduces an avoidable discontinuity at the overlap boundary; consider using an endpoint-inclusive ramp (or explicitly setting the inner boundary weight to 1.0) so the taper joins the flat region continuously.

Collaborator Author

Good catch! Fixed by switching to np.linspace(0.0, 1.0, overlap + 1, endpoint=True)[1:] so the taper reaches exactly 1.0 at the boundary - no more jump into the flat centre region. Added tests verifying continuity at the boundary for multiple overlap sizes.
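The difference between the two ramps in this thread can be checked with a short sketch. It reuses the half-cosine shaping and the power parameter from the quoted snippet; the concrete values overlap = 4 and power = 2 are only illustrative.

```python
import numpy as np

overlap, power = 4, 2

def shape(t: np.ndarray) -> np.ndarray:
    """Half-cosine shaping raised to `power`, as in the quoted snippet."""
    return (0.5 * (1.0 - np.cos(np.pi * t))) ** power

# Original ramp: starts at 0 and stops short of 1, so the next pixel jumps to 1.
old_ramp = shape(np.linspace(0.0, 1.0, overlap, endpoint=False))

# Revised ramp from the reply: drop the leading 0 and end exactly at 1.
new_ramp = shape(np.linspace(0.0, 1.0, overlap + 1, endpoint=True)[1:])

print(old_ramp[-1], new_ramp[-1])
```

One consequence of the revised form is that the outermost pixel of the tile no longer gets weight exactly 0, since the leading sample is dropped; it receives a small positive weight instead.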

Comment on lines +68 to +79
if overlap <= 0:
    return np.ones(window_size, dtype=np.float64)

window = np.ones(window_size, dtype=np.float64)
# Left taper: 0 .. overlap -> ramp from 0 to 1
ramp = np.linspace(0.0, 1.0, overlap, endpoint=False)
# Apply half-cosine shaping then power for smoothness
ramp = (0.5 * (1.0 - np.cos(np.pi * ramp))) ** power
window[:overlap] = ramp
# Right taper: mirror of left
window[-overlap:] = ramp[::-1]
return window
Copilot AI Mar 8, 2026

For BlendMode.SPLINE, _spline_window_1d() writes window[:overlap] and window[-overlap:] without guarding against overlap > window_size // 2. In that case the left/right taper regions overlap and overwrite each other, producing a distorted (and potentially non-monotone) window. Either constrain spline mode to overlap <= tile_size // 2 (raise a clear ValueError) or rework the construction to handle large overlaps similarly to the linear/cosine implementations.

Collaborator Author

Valid point - the left/right tapers were overwriting each other when overlap > tile_size // 2. Added a ValueError guard for spline mode with a message suggesting linear or cosine blending for larger overlaps (those modes already handle this correctly via np.minimum).
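The overwrite problem and the two remedies discussed in this thread can be sketched as below. All three helpers are hypothetical, and a linear ramp is used for brevity. In particular, minimum_linear_window shows only one way np.minimum can combine the two tapers; it is not necessarily how the PR's linear/cosine modes are written.

```python
import numpy as np

def naive_taper_window(window_size: int, overlap: int) -> np.ndarray:
    """Write left and right tapers with no guard (illustrative only)."""
    ramp = np.linspace(0.0, 1.0, overlap + 1)[1:]
    window = np.ones(window_size, dtype=np.float64)
    window[:overlap] = ramp
    # If overlap > window_size // 2 this overwrites part of the left taper.
    window[-overlap:] = ramp[::-1]
    return window

def check_taper_overlap(window_size: int, overlap: int) -> None:
    """Guard in the spirit of the fix: reject overlaps the tapers cannot fit."""
    if overlap > window_size // 2:
        raise ValueError(
            f"overlap ({overlap}) must be <= window_size // 2 "
            f"({window_size // 2}); consider linear or cosine blending"
        )

def minimum_linear_window(window_size: int, overlap: int) -> np.ndarray:
    """Combine both ramps with np.minimum so large overlaps stay symmetric.

    Assumes overlap > 0.
    """
    idx = np.arange(window_size)
    left = (idx + 1) / overlap             # rises from the left edge
    right = (window_size - idx) / overlap  # rises from the right edge
    return np.minimum(1.0, np.minimum(left, right))
```

With window_size=8 and overlap=6, naive_taper_window returns an asymmetric, non-monotone window, while minimum_linear_window stays symmetric. For overlap=4 the two constructions agree.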

Comment on lines +53 to +58
def _spline_window_1d(window_size: int, overlap: int, power: int = 2) -> np.ndarray:
    """Create a 1D spline window that tapers only in the overlap zones.

    The window is 1.0 in the non-overlapping centre and smoothly tapers
    to 0 at the edges using a raised half-cosine shaped by *power*.

Copilot AI Mar 8, 2026

The code labels this as a “spline” window and references Smoothly-Blend-Image-Patches, but _spline_window_1d() is currently a powered raised-cosine taper (no triangular/spline windowing like the referenced approach). This mismatch is confusing for users trying to reproduce the referenced method; either implement the intended spline/triangular-based window or rename/re-document the mode to accurately describe the actual weighting function.

Collaborator Author

Thanks for flagging. The naming is intentional - "spline" here refers to the smooth blending concept rather than the exact triangular window from the referenced repo. The docstring already described the actual implementation ("raised half-cosine shaped by power"). I've updated the enum docstring to say "powered raised-cosine taper" for additional clarity.

Comment on lines +496 to +502
out_count = output_array.shape[0] if output_array.ndim == 3 else 1
profile.update(
    count=out_count,
    dtype=output_dtype,
    nodata=output_nodata,
    compress=compress,
)
Copilot AI Mar 8, 2026

output_dtype is configurable, but the default output_nodata=-9999.0 is not representable for many integer dtypes (e.g. uint8), which will likely cause rasterio to raise when opening the output dataset with profile.update(nodata=...). Consider validating that output_nodata fits output_dtype (and raising a clear error) or choosing a dtype-appropriate default nodata when output_dtype is changed.

Collaborator Author

Great catch! Added validation that output_nodata fits within the valid range of output_dtype for integer types, raising a clear ValueError with guidance on valid alternatives.
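A check of the kind described in this reply might look like the sketch below. validate_nodata is a hypothetical helper and the error wording is illustrative; only the idea of comparing output_nodata against the integer range of output_dtype comes from the thread.

```python
import numpy as np

def validate_nodata(output_nodata, output_dtype) -> None:
    """Raise ValueError if output_nodata cannot be stored in an integer dtype."""
    if output_nodata is None:
        return
    dtype = np.dtype(output_dtype)
    if not np.issubdtype(dtype, np.integer):
        return  # float dtypes are not range-checked in this sketch
    info = np.iinfo(dtype)
    is_whole = float(output_nodata).is_integer()  # False for nan/inf/fractions
    if not is_whole or not info.min <= output_nodata <= info.max:
        raise ValueError(
            f"output_nodata={output_nodata} does not fit dtype {dtype.name} "
            f"(valid range {info.min}..{info.max})"
        )
```

With this check, validate_nodata(-9999.0, "uint8") raises, while validate_nodata(255, "uint8") and validate_nodata(-9999.0, "float32") pass.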

Comment on lines +313 to +323
"# No blending — hard tile boundaries (last-write-wins)\n",
"predict_geotiff(\n",
" model=model,\n",
" input_raster=input_path,\n",
" output_raster=output_no_blend,\n",
" tile_size=256,\n",
" overlap=64,\n",
" batch_size=4,\n",
" num_classes=1,\n",
" blend_mode=\"none\",\n",
" preprocess_fn=naip_preprocess,\n",
Copilot AI Mar 8, 2026

This notebook describes blend_mode="none" as “last-write-wins”, but the current predict_geotiff() implementation always normalizes accumulated predictions, so none results in uniform averaging across overlaps (not hard boundaries). Either update the notebook wording to match the actual behavior, or change the implementation so BlendMode.NONE truly overwrites in overlap regions.

Collaborator Author

This is addressed together with comment 6 - the BlendMode.NONE docstring has been corrected to "uniform averaging", which is what the implementation actually does.

Comment on lines +37 to +49
class BlendMode(str, Enum):
    """Blending strategy for overlapping tile predictions.

    Attributes:
        NONE: No blending; last-write-wins.
        LINEAR: Linear ramp from 0 at edges to 1 at center.
        COSINE: Raised-cosine (Hann) taper in the overlap region.
        SPLINE: Spline (triangular-based) window for smooth transitions.
    """

    NONE = "none"
    LINEAR = "linear"
    COSINE = "cosine"
Copilot AI Mar 8, 2026

BlendMode.NONE is documented as “No blending; last-write-wins”, but predict_geotiff() always does weighted accumulation + normalization, so blend_mode="none" actually averages overlapping tiles (uniform weights) rather than overwriting. Either implement true last-write-wins behavior for NONE (skip output_sum/weight_sum blending) or update the enum/docstrings (and docs notebook) to reflect that NONE means uniform averaging.

Collaborator Author

You're right - the implementation uses weighted accumulation with uniform weights (all 1.0), which produces averaging in overlap regions, not last-write-wins. Updated the docstring to accurately say "Uniform averaging - all pixels are weighted equally (1.0), so overlapping tiles are simply averaged without tapering."
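The distinction drawn in this thread can be seen in a small 1D sketch. The variable names out_sum and w_sum loosely mirror the output_sum/weight_sum accumulators mentioned in the review; the tile values are made up for illustration.

```python
import numpy as np

length, tile = 6, 4
# Two overlapping 1D "tiles": (start index, constant prediction).
tiles = [(0, np.full(tile, 1.0)), (2, np.full(tile, 3.0))]

out_sum = np.zeros(length)    # weighted sum of predictions
w_sum = np.zeros(length)      # sum of weights (uniform weight 1.0)
overwrite = np.zeros(length)  # last-write-wins alternative

for start, pred in tiles:
    out_sum[start:start + tile] += pred * 1.0
    w_sum[start:start + tile] += 1.0
    overwrite[start:start + tile] = pred

averaged = out_sum / w_sum
print(averaged)   # [1. 1. 2. 2. 3. 3.]
print(overwrite)  # [1. 1. 3. 3. 3. 3.]
```

In the overlap (indices 2 and 3), uniform weighting yields the mean of the two tiles, whereas last-write-wins keeps only the second tile's value. Averaging therefore shrinks the step at the seam compared with overwriting, but does not taper it.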

… validation

- Fix endpoint discontinuity in _spline_window_1d by using endpoint=True
- Guard spline mode against overlap > tile_size // 2 (raises ValueError)
- Update BlendMode.NONE docstring: uniform averaging, not last-write-wins
- Validate output_nodata fits output_dtype for integer types
- Add 11 new tests covering all fixes (46 total, all passing)

Development

Successfully merging this pull request may close these issues.

A more efficient inference function + smoothing

3 participants