# Rust CNN

A high-performance Convolutional Neural Network (CNN) for garbage classification, built from scratch in Rust with the tch crate (PyTorch bindings).

This project classifies waste into 6 categories: cardboard, glass, metal, paper, plastic, and trash, reaching roughly 75% validation accuracy (more with longer training).
- Model: Custom CNN with batch normalization, dropout, and data augmentation
- Classes: 6 garbage types (cardboard, glass, metal, paper, plastic, trash)
- Dataset Size: 13,901 images
- Performance: 75% validation accuracy
- Training Platform: Kaggle GPUs
## Dataset

- Source: Garbage Classification Dataset on Kaggle
- Total Images: 13,901
- Training Set: 11,120 images (80%)
- Validation Set: 2,781 images (20%)
- Image Size: 200x200 RGB
- Data Augmentation: Random horizontal/vertical flips, rotation, resizing
## Features

- Pooling: Max pooling and Global Average Pooling
- Classification: Fully connected layers with dropout (0.15)
- Regularization: Batch normalization, gradient clipping
- Optimizer: Adam with initial LR=1e-4
- Loss Function: Cross-entropy loss
- Batch Size: 32 (configurable)
- Device Support: Auto-detection of CUDA GPUs
- Confusion Matrix: Detailed misclassification analysis
- Loss Curves: Training/validation loss tracking
- Export: CSV files for external analysis + Python plotting scripts
## Installation

Install Rust:

```bash
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
source ~/.cargo/env
```

Download libtorch (CPU):

```bash
curl -O https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-2.1.0%2Bcpu.zip
```

For GPU (Linux) with CUDA 12.1:

```bash
curl -O https://download.pytorch.org/libtorch/cu121/libtorch-cxx11-abi-shared-with-deps-2.1.0%2Bcu121.zip
```

Unzip the downloaded archive, then point the build at it:

```bash
export LIBTORCH=/path/to/libtorch
export LD_LIBRARY_PATH=$LIBTORCH/lib:$LD_LIBRARY_PATH
```
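Once LIBTORCH and LD_LIBRARY_PATH are set, a small standalone program like the following (illustrative, not part of the project) is a quick way to confirm that tch links against libtorch and can see the GPU:

```rust
use tch::{Cuda, Device, Kind, Tensor};

fn main() {
    // Reports whether the CUDA build of libtorch was found at runtime.
    println!("CUDA available: {}", Cuda::is_available());

    // A trivial tensor op on the auto-detected device proves the linking works.
    let device = Device::cuda_if_available();
    let t = Tensor::ones(&[3], (Kind::Float, device));
    println!("sum = {}", t.sum(Kind::Float).double_value(&[]));
}
```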
Training modes (passed as the first command-line argument):

| Mode | Samples | Epochs | Device | Batch Size | Purpose |
|------|---------|--------|--------|------------|---------|
| 1 | 500 | 2 | CPU | 8 | Quick test |
| 2 | 2,000 | 5 | GPU | 16 | Development |
| 3 | All | 100 | GPU | 32 | Full training |
| 4 | Custom | Custom | Custom | Custom | Experimentation |
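As a purely hypothetical sketch of how such a mode switch might look (the real main.rs and menu.rs may be organized differently; RunConfig and its fields are invented for illustration):

```rust
/// Invented configuration struct for illustration; not the project's actual types.
struct RunConfig {
    max_samples: Option<usize>, // None = use the full dataset
    epochs: usize,
    batch_size: i64,
}

fn config_for_mode(mode: u32) -> Option<RunConfig> {
    // Mirrors the mode table above: 1 = quick test, 2 = development, 3 = full training.
    match mode {
        1 => Some(RunConfig { max_samples: Some(500), epochs: 2, batch_size: 8 }),
        2 => Some(RunConfig { max_samples: Some(2_000), epochs: 5, batch_size: 16 }),
        3 => Some(RunConfig { max_samples: None, epochs: 100, batch_size: 32 }),
        _ => None, // mode 4: prompt the user for custom values instead
    }
}

fn main() {
    let mode: u32 = std::env::args().nth(1).and_then(|s| s.parse().ok()).unwrap_or(1);
    match config_for_mode(mode) {
        Some(cfg) => println!("epochs = {}, batch size = {}", cfg.epochs, cfg.batch_size),
        None => println!("custom mode: ask for values interactively"),
    }
}
```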
## Results

- Training Time: ~1 hour 23 minutes (100 epochs, GPU)
- Final Training Loss: 0.6566
- Final Validation Loss: 0.7176
| Class | Precision | Recall | F1-Score | Support |
|-----------|-----------|--------|----------|---------|
| Cardboard | 0.79 | 0.83 | 0.81 | 436 |
| Glass | 0.67 | 0.68 | 0.67 | 440 |
| Metal | 0.81 | 0.70 | 0.75 | 478 |
| Paper | 0.74 | 0.72 | 0.73 | 506 |
| Plastic | 0.72 | 0.77 | 0.75 | 500 |
| Trash | 0.76 | 0.78 | 0.77 | 421 |
- Challenge Areas: Glass classification (visually similar to other materials)
- Common Confusions: Glass ↔ Paper/Plastic (expected due to visual similarity)
- Learning Curve: Smooth convergence without overfitting
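The per-class numbers and confusion analysis above come from a confusion matrix; a helper along these lines (illustrative, not the actual code in training_validation.rs) is enough to accumulate one from logits and true labels:

```rust
use tch::Tensor;

/// n_classes x n_classes matrix: rows are true classes, columns are predicted classes.
fn confusion_matrix(logits: &Tensor, targets: &Tensor, n_classes: usize) -> Vec<Vec<u32>> {
    let preds = logits.argmax(-1, false); // (N,) predicted class indices
    let mut matrix = vec![vec![0u32; n_classes]; n_classes];
    for i in 0..targets.size()[0] {
        let t = targets.int64_value(&[i]) as usize;
        let p = preds.int64_value(&[i]) as usize;
        matrix[t][p] += 1;
    }
    matrix
}
```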
## Kaggle Integration

Live Demo: Garbage Classifier Rust on Kaggle

This project was developed and tested in Kaggle Notebooks using free GPU access, demonstrating the feasibility of Rust-based deep learning on cloud platforms. The linked notebook contains a full working example:
```
# Clone repository
!git clone https://github.com/Not-Buddy/garbage_classification.git
%cd garbage_classification

# Install Rust + dependencies (handled in notebook)

# Run training
!cargo run --release -- 3
```
## Project Structure

```
garbage_classification/
├── src/
│   ├── main.rs                  # Entry point with CLI argument parsing
│   ├── menu.rs                  # Configuration management + dataset loading
│   ├── relu.rs                  # CNN model architecture
│   ├── train_model.rs           # Training loop and optimization
│   ├── training_validation.rs   # Evaluation metrics and plotting
│   └── visualize_data.rs        # Data loading and batch visualization
├── Cargo.toml                   # Dependencies and project configuration
├── README.md                    # This file
└── plot_results.py              # Generated Python plotting script
```
## Technical Implementation

### Dataset Loading

- Parallel Processing: Rayon for multi-threaded image loading
- Augmentation Pipeline: Random flips, rotation, resizing
- Memory Efficiency: Tensor operations in CHW format
- Progress Tracking: Real-time processing counter
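A rough sketch of that pipeline, assuming the rayon and rand crates alongside tch (function names are illustrative, not the ones in visualize_data.rs): load files in parallel, resize to 200x200, apply a random horizontal flip, and stack the CHW tensors into a single batch.

```rust
use rayon::prelude::*;
use tch::{vision::image, Kind, Tensor};

/// Load images in parallel, resize to 200x200, scale to [0, 1], randomly flip,
/// and stack into an (N, 3, 200, 200) float batch.
fn load_batch(paths: &[std::path::PathBuf]) -> Tensor {
    let tensors: Vec<Tensor> = paths
        .par_iter()
        .filter_map(|p| image::load(p).ok()) // CHW uint8 tensor
        .map(|t| {
            let t = image::resize(&t, 200, 200).unwrap().to_kind(Kind::Float) / 255.0;
            // Random horizontal flip: dim 2 is the width axis in CHW layout.
            if rand::random::<bool>() { t.flip(&[2]) } else { t }
        })
        .collect();
    Tensor::stack(&tensors, 0)
}
```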
### Model Architecture

```
CNN {
    features:   Conv2d(32) -> BN -> ReLU -> Conv2d(32) -> BN -> ReLU -> MaxPool ->
                Conv2d(64) -> BN -> ReLU -> Conv2d(64) -> BN -> ReLU -> MaxPool ->
                Conv2d(128) -> BN -> ReLU -> Conv2d(128) -> BN -> ReLU -> MaxPool,
    gap:        GlobalAveragePooling2d,
    classifier: Linear(128) -> ReLU -> Dropout(0.15) -> Linear(6)
}
```
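The same stack expressed with tch's nn module, as an approximation of relu.rs; the layer names and the 3x3/padding-1 convolution settings are assumptions, not copied from the source:

```rust
use tch::nn;

// Conv -> BatchNorm -> ReLU, the repeated unit in the diagram above.
fn conv_block(p: nn::Path, c_in: i64, c_out: i64) -> nn::SequentialT {
    let cfg = nn::ConvConfig { padding: 1, ..Default::default() };
    nn::seq_t()
        .add(nn::conv2d(&p / "conv", c_in, c_out, 3, cfg))
        .add(nn::batch_norm2d(&p / "bn", c_out, Default::default()))
        .add_fn(|xs| xs.relu())
}

fn cnn(p: &nn::Path) -> nn::SequentialT {
    nn::seq_t()
        .add(conv_block(p / "b1a", 3, 32))
        .add(conv_block(p / "b1b", 32, 32))
        .add_fn(|xs| xs.max_pool2d_default(2))
        .add(conv_block(p / "b2a", 32, 64))
        .add(conv_block(p / "b2b", 64, 64))
        .add_fn(|xs| xs.max_pool2d_default(2))
        .add(conv_block(p / "b3a", 64, 128))
        .add(conv_block(p / "b3b", 128, 128))
        .add_fn(|xs| xs.max_pool2d_default(2))
        // Global average pooling: collapse each 128-channel map to a single value.
        .add_fn(|xs| xs.adaptive_avg_pool2d(&[1, 1]).flatten(1, -1))
        .add(nn::linear(p / "fc1", 128, 128, Default::default()))
        .add_fn(|xs| xs.relu())
        .add_fn_t(|xs, train| xs.dropout(0.15, train))
        .add(nn::linear(p / "fc2", 128, 6, Default::default()))
}
```

Built on a VarStore rooted at the auto-detected device, `cnn(&vs.root())` yields a ModuleT whose `forward_t` can drive both training and evaluation.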
### Training

- Learning Rate Scheduling: Exponential decay
- Early Stopping: Validation-based (configurable)
- Gradient Clipping: Prevents exploding gradients
- Progress Monitoring: Real-time loss and accuracy tracking
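A condensed sketch of one way to wire these pieces together with tch's Optimizer: exponential decay via set_lr, gradient clipping via backward_step_clip, and a per-epoch accuracy report. The decay factor and clip threshold are placeholder values, not the project's.

```rust
use tch::{nn, nn::{ModuleT, OptimizerConfig}, Tensor};

/// Illustrative training routine; tensor shapes follow this README (200x200 RGB, 6 classes).
fn train_loop<M: ModuleT>(
    vs: &nn::VarStore,
    model: &M,
    train_xs: &Tensor, train_ys: &Tensor, // (N, 3, 200, 200) images, (N,) labels
    val_xs: &Tensor, val_ys: &Tensor,
    epochs: usize,
) -> Result<(), tch::TchError> {
    let mut opt = nn::Adam::default().build(vs, 1e-4)?;
    let mut lr = 1e-4;
    for epoch in 1..=epochs {
        for (bx, by) in tch::data::Iter2::new(train_xs, train_ys, 32)
            .shuffle()
            .to_device(vs.device())
        {
            let loss = model.forward_t(&bx, true).cross_entropy_for_logits(&by);
            // Clip gradients before stepping to keep them from exploding.
            opt.backward_step_clip(&loss, 0.5);
        }
        // Exponential learning-rate decay.
        lr *= 0.97;
        opt.set_lr(lr);
        let acc = model.batch_accuracy_for_logits(val_xs, val_ys, vs.device(), 64);
        println!("epoch {epoch}: lr {lr:.2e}, validation accuracy {acc:.3}");
    }
    Ok(())
}
```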
### Output Files

- training_losses.csv - Loss data for plotting
- confusion_matrix.csv - Confusion matrix data
- evaluation_results.csv - Detailed predictions with probabilities
- training_stats.csv - Complete training metrics
- plot_results.py - Python script for visualization
- garbage_classifier_X_epochs_Y_samples.pt - Saved model weights
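A rough sketch (not the project's exact code) of how such files can be produced: per-epoch losses written to training_losses.csv with std::fs, and the weights saved through tch's VarStore. The CSV columns here are an assumption; the real files may contain more fields.

```rust
use std::io::Write;
use tch::nn;

/// Write loss curves as CSV and save the trained weights.
fn save_outputs(
    vs: &nn::VarStore,
    train_losses: &[f64],
    val_losses: &[f64],
    model_path: &str, // e.g. "garbage_classifier_X_epochs_Y_samples.pt"
) -> std::io::Result<()> {
    let mut f = std::fs::File::create("training_losses.csv")?;
    writeln!(f, "epoch,train_loss,val_loss")?;
    for (i, (tr, va)) in train_losses.iter().zip(val_losses).enumerate() {
        writeln!(f, "{},{:.6},{:.6}", i + 1, tr, va)?;
    }
    // VarStore::save serializes every registered variable in libtorch's format.
    vs.save(model_path)
        .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()))?;
    Ok(())
}
```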
## Future Improvements

- Attention Mechanisms: Focus on distinguishing features
- Ensemble Methods: Multiple model voting
- Advanced Augmentation: MixUp, CutMix, AutoAugment
- Hyperparameter Tuning: Automated search
- Model Serving: REST API for inference
- Cross-Validation: More robust evaluation
## Why Rust?

- Concurrency: Excellent parallel processing support
- Ecosystem: Growing ML/DL ecosystem with tch
- Deployment: Single binary, no runtime dependencies
- Reliability: Compile-time guarantees prevent common ML bugs
## Contributing

1. Create your feature branch (`git checkout -b feature/amazing-feature`)
2. Commit your changes (`git commit -m 'Add amazing feature'`)
3. Push to the branch (`git push origin feature/amazing-feature`)
4. Open a Pull Request
## License

This project is open source and available under the MIT License.

## Acknowledgments

- tch crate for PyTorch bindings
- Kaggle for free GPU access
- Dataset contributors for the garbage classification dataset
- Rust community for excellent documentation and support
⭐ Star this repository if you found it helpful!
For questions or collaboration opportunities, feel free to open an issue or reach out directly.