
Improvement on the Training API #5194

@vfdev-5

Description

  • Improve the current gemma example to be idiomatic NNX, related to Updates in gemma example #4913
  • Make it work with any Bonsai LLM
  • Support explicit sharding (a rough illustration follows the API examples below)
  • Add Hijax Support
  • Expose a Python API
    • The focus is readability: a simple recipe for training
    • Multi-host support?
    • Adding profiling features + a profiling guide?
# version 1 (unlikely)
trainer = Trainer(model, optimizer, loss_fn, train_step, ...)
trainer.run(train_data, num_epochs=10)

or

# version 2
from nnx.training import train_image_model, finetune_llm

model = ...
dataset = ...

train_image_model(model, dataset, eval_steps=100, ...)
train_image_model(model, dataset, config)
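
For reference, a minimal sketch of the loop a recipe like version 2 could wrap, following the common NNX train-step pattern (nnx.Optimizer, nnx.jit, nnx.value_and_grad). The body is illustrative only: the nnx.Optimizer constructor/update signature has shifted across Flax releases, and the batch keys are assumptions.

import optax
from flax import nnx

def train_image_model(model, dataset, *, eval_steps=100, learning_rate=1e-3):
    # Sketch: the recipe owns the optimizer and the jitted step.
    # nnx.Optimizer's exact signature varies by Flax version.
    optimizer = nnx.Optimizer(model, optax.adamw(learning_rate))

    @nnx.jit
    def train_step(model, optimizer, batch):
        def loss_fn(model):
            logits = model(batch['image'])  # assumed batch layout
            return optax.softmax_cross_entropy_with_integer_labels(
                logits, batch['label']
            ).mean()

        loss, grads = nnx.value_and_grad(loss_fn)(model)
        optimizer.update(grads)  # updates the model in place
        return loss

    for step, batch in enumerate(dataset):
        loss = train_step(model, optimizer, batch)
        if step % eval_steps == 0:
            print(f'step {step}: loss={float(loss):.4f}')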
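On the explicit-sharding point above: the issue does not say which JAX sharding mode is targeted, so purely as an illustration, the stable NamedSharding API expresses the usual data-parallel layout (batch split along a device axis, parameters replicated). The axis name and the dummy batch shape are made up.

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Illustrative 1-D data-parallel mesh over all local devices.
mesh = Mesh(np.array(jax.devices()), axis_names=('data',))

batch_sharding = NamedSharding(mesh, P('data'))  # split leading axis
replicated = NamedSharding(mesh, P())            # e.g. for parameters

batch = np.zeros((32, 28, 28, 1), dtype=np.float32)  # dummy batch
sharded_batch = jax.device_put(batch, batch_sharding)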
