37 commits
137a37f  initial code dump (divyashreepathihalli, Aug 11, 2025)
8b37482  clean up the implementation of the distillation api (divyashreepathihalli, Aug 11, 2025)
8252b8f  code reformat (divyashreepathihalli, Aug 11, 2025)
9bdec23  final clean up (divyashreepathihalli, Aug 11, 2025)
6efecee  pre commit (divyashreepathihalli, Aug 11, 2025)
36930e8  Merge branch 'keras-team:master' into distillation-api (divyashreepathihalli, Aug 11, 2025)
1f73a69  address gemini review comments (divyashreepathihalli, Aug 12, 2025)
88c2468  address gemini review comments (divyashreepathihalli, Aug 12, 2025)
9de5809  add a way to save trained student model (divyashreepathihalli, Aug 12, 2025)
b954718  disable tests in numpy and openvino backends (divyashreepathihalli, Aug 12, 2025)
bf6219a  pre commit (divyashreepathihalli, Aug 12, 2025)
b7e51a9  address comments (divyashreepathihalli, Aug 15, 2025)
e8229c2  address comments (divyashreepathihalli, Aug 15, 2025)
387595a  run pre-commit (divyashreepathihalli, Aug 16, 2025)
4d6610a  update distiller and strategies (divyashreepathihalli, Aug 18, 2025)
a109178  code reformat (divyashreepathihalli, Aug 18, 2025)
7c13687  Merge branch 'keras-team:master' into distillation-api (divyashreepathihalli, Aug 25, 2025)
5b6bf03  clean up (divyashreepathihalli, Aug 25, 2025)
de73fa6  code reformat (divyashreepathihalli, Aug 25, 2025)
5cd56bf  remove multi output distillation (divyashreepathihalli, Aug 28, 2025)
0b2d88f  clean up after merge (divyashreepathihalli, Aug 28, 2025)
b00d4a4  Merge branch 'keras-team:master' into distillation-api (divyashreepathihalli, Aug 29, 2025)
9d8242c  address comments (divyashreepathihalli, Aug 29, 2025)
9c6a70c  update file names (divyashreepathihalli, Sep 2, 2025)
a7d0b54  subclass logits distillation loss from feature distillation loss (divyashreepathihalli, Sep 3, 2025)
1607807  update docstrings (divyashreepathihalli, Sep 3, 2025)
5659310  add validation for feature extraction setup (divyashreepathihalli, Sep 3, 2025)
d3b27b3  Fix distiller API to accept single strategy or list of strategies (divyashreepathihalli, Sep 4, 2025)
bbe0868  minor fixes (divyashreepathihalli, Sep 4, 2025)
7775079  fix tests (divyashreepathihalli, Sep 5, 2025)
a078fb4  address review comments (divyashreepathihalli, Sep 9, 2025)
e2f7f6b  oh (divyashreepathihalli, Sep 9, 2025)
a730c4a  Merge branch 'keras-team:master' into distillation-api (divyashreepathihalli, Sep 9, 2025)
ed6768f  Merge branch 'keras-team:master' into distillation-api (divyashreepathihalli, Sep 23, 2025)
649d7f2  Merge branch 'keras-team:master' into distillation-api (divyashreepathihalli, Sep 30, 2025)
a5f4605  address review comments (divyashreepathihalli, Sep 30, 2025)
df07758  address review comments (divyashreepathihalli, Oct 6, 2025)
1 change: 1 addition & 0 deletions keras/api/__init__.py
@@ -12,6 +12,7 @@
from keras import config as config
from keras import constraints as constraints
from keras import datasets as datasets
from keras import distillation as distillation
from keras import distribution as distribution
from keras import dtype_policies as dtype_policies
from keras import export as export
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/__init__.py
@@ -10,6 +10,7 @@
from keras import config as config
from keras import constraints as constraints
from keras import datasets as datasets
from keras import distillation as distillation
from keras import distribution as distribution
from keras import dtype_policies as dtype_policies
from keras import export as export
19 changes: 19 additions & 0 deletions keras/api/_tf_keras/keras/distillation/__init__.py
@@ -0,0 +1,19 @@
"""DO NOT EDIT.

This file was autogenerated. Do not edit it by hand,
since your modifications would be overwritten.
"""

from keras.src.distillation.distiller import Distiller as Distiller
from keras.src.distillation.strategies import (
BaseDistillationStrategy as BaseDistillationStrategy,
)
from keras.src.distillation.strategies import (
FeatureDistillation as FeatureDistillation,
)
from keras.src.distillation.strategies import (
LogitsDistillation as LogitsDistillation,
)
from keras.src.distillation.strategies import (
MultiOutputDistillation as MultiOutputDistillation,
)
19 changes: 19 additions & 0 deletions keras/api/distillation/__init__.py
@@ -0,0 +1,19 @@
"""DO NOT EDIT.

This file was autogenerated. Do not edit it by hand,
since your modifications would be overwritten.
"""

from keras.src.distillation.distiller import Distiller as Distiller
from keras.src.distillation.strategies import (
BaseDistillationStrategy as BaseDistillationStrategy,
)
from keras.src.distillation.strategies import (
FeatureDistillation as FeatureDistillation,
)
from keras.src.distillation.strategies import (
LogitsDistillation as LogitsDistillation,
)
from keras.src.distillation.strategies import (
MultiOutputDistillation as MultiOutputDistillation,
)
1 change: 1 addition & 0 deletions keras/src/distillation/__init__.py
@@ -0,0 +1 @@
"""Distillation module for knowledge distillation in Keras."""
331 changes: 331 additions & 0 deletions keras/src/distillation/distiller.py
@@ -0,0 +1,331 @@
import keras
from keras.src.api_export import keras_export
from keras.src.models.model import Model


@keras_export("keras.distillation.Distiller")
class Distiller(Model):
"""Knowledge Distillation model.

This class implements knowledge distillation by combining a teacher model
and a student model with configurable distillation strategies.

The Distiller integrates seamlessly with Keras's training infrastructure
by overriding the _compute_loss method, allowing standard model.fit(),
model.evaluate(), and model.predict() workflows to work correctly.

Args:
teacher: The teacher model (will be frozen during training).
student: The student model to be trained.
strategies: A single distillation strategy or a list of strategies
to apply.
student_loss_fn: Loss function for student predictions. Defaults to
sparse categorical crossentropy.
alpha: Weight for combining student loss and distillation loss.
alpha=1.0 means only student loss, alpha=0.0 means only
distillation loss.
temperature: Temperature for softmax in distillation (used by
strategies).
name: Name of the distiller model.

Examples:

**Basic Knowledge Distillation:**

```python
import keras
import numpy as np
from keras.distillation import Distiller, LogitsDistillation

# Create teacher and student models
teacher = keras.Sequential([
keras.layers.Dense(128, activation='relu'),
keras.layers.Dense(10, activation='softmax')
])

student = keras.Sequential([
keras.layers.Dense(32, activation='relu'),
keras.layers.Dense(10, activation='softmax')
])

# Create distillation strategy
strategy = LogitsDistillation(temperature=3.0)

# Create distiller
distiller = Distiller(
teacher=teacher,
student=student,
strategies=[strategy],
alpha=0.7, # 70% student loss, 30% distillation loss
temperature=3.0
)

# Compile and train
distiller.compile(
optimizer='adam',
loss='sparse_categorical_crossentropy'
)

# Generate dummy data
x_train = np.random.random((1000, 20))
y_train = np.random.randint(0, 10, (1000,))

# Train the distiller
distiller.fit(x_train, y_train, epochs=10, batch_size=32)

# Use the trained student model
predictions = distiller.predict(x_train[:5])
```

**Multi-Strategy Distillation:**

```python
from keras.distillation import (
Distiller, LogitsDistillation, FeatureDistillation
)

# Multiple distillation strategies
strategies = [
LogitsDistillation(temperature=4.0),
FeatureDistillation(loss_type="mse")
]

distiller = Distiller(
teacher=teacher,
student=student,
strategies=strategies,
alpha=0.5
)
```

**Multi-Output Model Distillation:**

```python
from keras.distillation import MultiOutputDistillation

# For models with multiple outputs
multi_strategy = MultiOutputDistillation(
output_strategies={
0: LogitsDistillation(temperature=3.0, output_index=0),
1: LogitsDistillation(temperature=2.0, output_index=1)
},
weights={0: 1.0, 1: 0.5}
)

distiller = Distiller(
teacher=multi_output_teacher,
student=multi_output_student,
strategies=[multi_strategy]
)
```

**Custom Loss Function:**

```python
distiller = Distiller(
teacher=teacher,
student=student,
strategies=[LogitsDistillation()],
student_loss_fn=keras.losses.CategoricalCrossentropy(),
alpha=0.8
)
```
"""

def __init__(
self,
teacher,
student,
strategies,
student_loss_fn=None,
alpha=0.5,
temperature=3.0,
name="distiller",
**kwargs,
):
super().__init__(name=name, **kwargs)

# Validate inputs
self._validate_models(teacher, student)

# Store configuration
self.teacher = teacher
self.student = student
self.strategies = (
strategies if isinstance(strategies, list) else [strategies]
)
self.alpha = alpha
self.temperature = temperature

# Set up student loss function
if student_loss_fn is None:
self.student_loss_fn = keras.losses.SparseCategoricalCrossentropy()
else:
self.student_loss_fn = student_loss_fn

# Freeze teacher model
self.teacher.trainable = False

# Initialize loss tracking metrics
self.student_loss_tracker = keras.metrics.Mean(name="student_loss")
self.distillation_loss_tracker = keras.metrics.Mean(
name="distillation_loss"
)
self.total_loss_tracker = keras.metrics.Mean(name="total_loss")

def _validate_models(self, teacher, student):
"""Validate that teacher and student are Keras models."""
if not isinstance(teacher, keras.Model):
raise ValueError(
f"Teacher must be a keras.Model, got {type(teacher)}"
)
if not isinstance(student, keras.Model):
raise ValueError(
f"Student must be a keras.Model, got {type(student)}"
)

def call(self, inputs, training=None, **kwargs):
"""Forward pass returns student predictions."""
return self.student(inputs, training=training, **kwargs)

def _compute_loss(
self, x=None, y=None, y_pred=None, sample_weight=None, training=None
):
"""Compute combined distillation loss.

This method integrates distillation into Keras's standard training
workflow.
"""
# Get student predictions
if y_pred is None:
y_pred = self(x, training=training)

# Get teacher predictions (no gradients)
teacher_outputs = self.teacher(x, training=False)
teacher_outputs = keras.ops.stop_gradient(teacher_outputs)

# Normalize outputs for consistent handling
student_outputs = (
[y_pred] if not isinstance(y_pred, (list, tuple)) else list(y_pred)
)
teacher_outputs = (
[teacher_outputs]
if not isinstance(teacher_outputs, (list, tuple))
else list(teacher_outputs)
)

# Validate outputs with strategies
for strategy in self.strategies:
strategy.validate_outputs(teacher_outputs, student_outputs)

# Compute student loss
if self.alpha > 0:
if hasattr(self, "compiled_loss") and self.compiled_loss:
student_loss = self.compiled_loss(
y, y_pred, sample_weight=sample_weight
)
else:
# Fallback to using student_loss_fn directly
# Handle multi-output case
if isinstance(y_pred, (list, tuple)):
# For multi-output models, use the first output for student
# loss
# This is a simplified approach for compatibility
if isinstance(y, (list, tuple)):
student_loss = self.student_loss_fn(y[0], y_pred[0])
else:
student_loss = self.student_loss_fn(y, y_pred[0])
Contributor review comment (severity: medium):
The fallback logic for calculating the student loss in _compute_loss for multi-output models is overly simplistic as it always defaults to using the first output (y_pred[0]). This might not align with user expectations for all multi-output scenarios and could lead to incorrect training behavior if model.compile() is not called with a loss that properly handles multiple outputs.

While the primary path using self.compiled_loss is correct, this fallback could be made more robust. Consider raising a more specific error if a multi-output model is used without a compiled loss, or clarifying this behavior more explicitly in the documentation.
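A minimal sketch of the stricter fallback this comment suggests (hypothetical; not part of the PR's code):

```python
# Hypothetical replacement for the fallback branch: fail loudly for
# multi-output models instead of silently distilling only y_pred[0].
if isinstance(y_pred, (list, tuple)):
    raise ValueError(
        "Multi-output distillation requires a compiled loss that "
        "covers every output. Call `distiller.compile(loss=...)` "
        "with per-output losses instead of relying on "
        "`student_loss_fn`."
    )
student_loss = self.student_loss_fn(y, y_pred)
```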

else:
student_loss = self.student_loss_fn(y, y_pred)
else:
student_loss = 0.0

# Compute distillation loss
distillation_loss = 0.0
for strategy in self.strategies:
distillation_loss += strategy.compute_loss(
teacher_outputs, student_outputs
)

# Combine losses
total_loss = (
self.alpha * student_loss + (1 - self.alpha) * distillation_loss
)

# Update metrics
self.student_loss_tracker.update_state(
student_loss if self.alpha > 0 else 0.0
)
self.distillation_loss_tracker.update_state(
distillation_loss if self.alpha < 1 else 0.0
)
self.total_loss_tracker.update_state(total_loss)

return total_loss

@property
def metrics(self):
"""Return metrics for monitoring."""
# Combine parent metrics with our loss trackers
parent_metrics = []
if hasattr(super(), "metrics"):
for metric in super().metrics:
if hasattr(metric, "variables") and hasattr(
metric, "update_state"
):
parent_metrics.append(metric)

return parent_metrics + [
self.student_loss_tracker,
self.distillation_loss_tracker,
self.total_loss_tracker,
]

def reset_metrics(self):
"""Reset all metrics."""
try:
super().reset_metrics()
except AttributeError:
pass

self.student_loss_tracker.reset_state()
self.distillation_loss_tracker.reset_state()
self.total_loss_tracker.reset_state()

def get_config(self):
"""Get model configuration for serialization."""
config = super().get_config()
config.update(
{
"teacher": keras.utils.serialize_keras_object(self.teacher),
"student": keras.utils.serialize_keras_object(self.student),
"strategies": [
keras.utils.serialize_keras_object(s)
for s in self.strategies
],
"student_loss_fn": keras.utils.serialize_keras_object(
self.student_loss_fn
),
"alpha": self.alpha,
"temperature": self.temperature,
}
)
return config

@classmethod
def from_config(cls, config):
"""Create model from configuration."""
config = config.copy()
config["teacher"] = keras.utils.deserialize_keras_object(
config["teacher"]
)
config["student"] = keras.utils.deserialize_keras_object(
config["student"]
)
config["strategies"] = [
keras.utils.deserialize_keras_object(s)
for s in config["strategies"]
]
config["student_loss_fn"] = keras.utils.deserialize_keras_object(
config["student_loss_fn"]
)
return cls(**config)
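For reference, a minimal round-trip sketch exercising the `get_config()` / `from_config()` pair above (assumes `distiller` was constructed and compiled as in the docstring examples; not part of the PR):

```python
from keras.distillation import Distiller

# Serialize the distiller to a plain config dict and rebuild it.
config = distiller.get_config()
restored = Distiller.from_config(config)

# The rebuilt model needs its own compile() before further training.
restored.compile(optimizer="adam", loss="sparse_categorical_crossentropy")
```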