@@ -161,3 +161,71 @@ trainer.fit(
161161```
162162
163163</details >
164+
165+ ## Model Registry Mixins
166+
167+ Lightning Models provides mixin classes that simplify pushing models to and pulling models from the registry.
168+ These mixins can be integrated directly into the model classes.
169+
170+ ### Available Mixins
171+
172+ 1 . ** PickleRegistryMixin** : For serializing any Python class with pickle
173+ 2 . ** PyTorchRegistryMixin** : For PyTorch models, preserving both weights and constructor arguments
174+
175+ Using these mixins provides several advantages:
176+
177+ - Direct integration into the model classes
178+ - Simplified save/load workflow
179+ - Automatic handling of model metadata and constructor arguments
180+ - Version management support
181+
182+ ### Using ` PickleRegistryMixin `
183+
184+ Add the mixin to a Python class for seamless registry integration:
185+
186+ ``` python
187+ from litmodels.integrations.mixins import PickleRegistryMixin
188+
189+
190+ class MyModel (PickleRegistryMixin ):
191+ def __init__ (self , param1 , param2 ):
192+ self .param1 = param1
193+ self .param2 = param2
194+ # Your model initialization code
195+
196+
197+ # Create and push a model instance
198+ model = MyModel(param1 = 42 , param2 = " hello" )
199+ model.push_to_registry(name = " my-org/my-team/my-model" )
200+
201+ # Later, pull the model
202+ loaded_model = MyModel.pull_from_registry(name = " my-org/my-team/my-model" )
203+ ```
204+
205+ ### Using ` PyTorchRegistryMixin `
206+
207+ This mixin preserves both the model architecture and weights:
208+
209+ ``` python
210+ import torch
211+ from litmodels.integrations.mixins import PyTorchRegistryMixin
212+
213+
214+ # Important: PyTorchRegistryMixin must be first in the inheritance order
215+ class MyTorchModel (PyTorchRegistryMixin , torch .nn .Module ):
216+ def __init__ (self , input_size , hidden_size = 128 ):
217+ super ().__init__ ()
218+ self .linear = torch.nn.Linear(input_size, hidden_size)
219+ self .activation = torch.nn.ReLU()
220+
221+ def forward (self , x ):
222+ return self .activation(self .linear(x))
223+
224+
225+ # Create and push the model
226+ model = MyTorchModel(input_size = 784 )
227+ model.push_to_registry(name = " my-org/my-team/torch-model" )
228+
229+ # Pull the model with the same architecture
230+ loaded_model = MyTorchModel.pull_from_registry(name = " my-org/my-team/torch-model" )
231+ ```
0 commit comments