Add FocalLoss, class weights, and enhanced tile filtering#503
Add FocalLoss, class weights, and enhanced tile filtering#503
Conversation
…cover training - Add FocalLoss class to train.py with configurable alpha, gamma, ignore_index (supports int or False to disable), reduction, and per-class weights - Add get_loss_function() helper supporting 'crossentropy' and 'focal' losses with flexible ignore_index (int or False) and optional class weights - Add compute_class_weights() to compute per-class weights from label tiles with inverse-frequency mode, custom multipliers, and weight capping - Wire new loss functions into train_segmentation_model() with parameters: loss_function, ignore_index, use_class_weights, class_weights, custom_class_multipliers, max_class_weight, use_inverse_frequency, focal_alpha, focal_gamma - Add min_feature_ratio parameter to export_geotiff_tiles() for filtering tiles with insufficient non-background pixels during tile export - Export new public functions from __init__.py: FocalLoss, get_loss_function, compute_class_weights, plus landcover module exports - All new parameters have backward-compatible defaults Closes #335
|
🚀 Deployed on https://697c96880a5be10bbd1b5820--opengeos.netlify.app |
There was a problem hiding this comment.
Pull request overview
This PR extends the training and tiling utilities to better support highly imbalanced, discrete landcover segmentation workflows, and exposes the new landcover-specific helpers at the package root.
Changes:
- Adds a general-purpose
FocalLoss,get_loss_function, andcompute_class_weightsintrain.py, and wires them intotrain_segmentation_modelto support focal loss, class weighting, and flexibleignore_index. - Extends
export_geotiff_tilesinutils.pywith amin_feature_ratioparameter to skip tiles with too few non-background pixels and report ratio-based skipping statistics. - Updates
__init__.pyto export the new training helpers and landcover-specific utilities (LandcoverCrossEntropyLoss,landcover_iou,get_landcover_loss_function,train_segmentation_landcover,export_landcover_tiles).
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
geoai/train.py |
Introduces FocalLoss, a configurable loss factory get_loss_function, class-weight computation via compute_class_weights, and extends train_segmentation_model to use these for class-imbalanced landcover training. |
geoai/utils.py |
Adds min_feature_ratio validation, per-tile ratio filtering, and summary stats reporting to export_geotiff_tiles to drop mostly-background tiles. |
geoai/__init__.py |
Re-exports the new loss/weight utilities from train.py and landcover-specific training and tiling helpers for easier package-level access. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| return weights | ||
|
|
||
|
|
||
| def train_segmentation_model( |
There was a problem hiding this comment.
The function signature declares a return type of torch.nn.Module, but the implementation does not return a model object (it relies on side effects and saving to output_dir), and the docstring states that it returns None. Please align the return type annotation and documentation with the actual behavior (either return the model or change the annotation/docstring to indicate that the function returns None).
geoai/utils.py
Outdated
| print( | ||
| f"Average feature pixels per tile: {stats['feature_pixels']/stats['tiles_with_features']:.1f}" | ||
| ) |
There was a problem hiding this comment.
When in_class_data is a raster, stats['feature_pixels'] is incremented as soon as a tile has any features, but tiles can still be skipped later by the min_feature_ratio filter. As a result, the "Average feature pixels per tile" summary can include pixels from tiles that were not actually exported, which makes this statistic misleading; it would be more accurate to count feature pixels only for tiles that are written.
| print( | |
| f"Average feature pixels per tile: {stats['feature_pixels']/stats['tiles_with_features']:.1f}" | |
| ) | |
| if min_feature_ratio is False: | |
| print( | |
| f"Average feature pixels per tile: {stats['feature_pixels']/stats['tiles_with_features']:.1f}" | |
| ) | |
| else: | |
| print( | |
| "Average feature pixels per tile is not reported when a min_feature_ratio filter is applied." | |
| ) |
| class_counts: Counter = Counter() | ||
| total_pixels = 0 |
There was a problem hiding this comment.
compute_class_weights uses Counter here but Counter is never imported in this module, so calling this function will raise a NameError at runtime. Please import Counter from the collections module (or remove the annotation and use a plain dict) so that this function works correctly.
|
Hello Professor Wu, I tried reinstalling geoai on another system and the modules I added didnt seem to work properly... I had to change a few things in them. Sorry about that, I cant push a working version next week I think. |
|
@gw0ods Can you open a new issue and describe what is not working? |
on it! |
Summary
Implements the enhancements from #335 to support training with discrete class landcover data.
Changes
New in
train.py:FocalLossclass — focal loss for class imbalance (FL = -α(1-pt)^γ log(pt)), with configurable alpha, gamma, ignore_index (int orFalseto disable), reduction, and per-class weights. UsesF.log_softmax+F.nll_lossinternally.get_loss_function()— factory returning configuredCrossEntropyLossorFocalLosswith flexibleignore_indexand optional class weights.compute_class_weights()— computes per-class weights from label tiles with inverse-frequency mode, custom multipliers, and weight capping.train_segmentation_model()— wired with new parameters:loss_function,ignore_index,use_class_weights,class_weights,custom_class_multipliers,max_class_weight,use_inverse_frequency,focal_alpha,focal_gamma.New in
utils.py:min_feature_ratioparameter onexport_geotiff_tiles()— filters out tiles where the ratio of non-background pixels is below a threshold (only applies whenskip_empty_tiles=Trueand label data is provided). Summary stats are reported at the end.New in
__init__.py:FocalLoss,get_loss_function,compute_class_weightsfromtrain.pyLandcoverCrossEntropyLoss,landcover_iou,get_landcover_loss_function,train_segmentation_landcoverfromlandcover_train.pyexport_landcover_tilesfromlandcover_utils.pyBackward Compatibility
All new parameters have sensible defaults — existing code continues to work without changes.
Closes #335
Closes #335