Training API #172
Feature Description
Add a useful API for the training cycle so that users don't need to write a training loop from scratch each time.
Add new methods to Job:
- Job.request()
  Does what we currently do inside the .start() method (authentication, downloading the model and plan).
- trainingProcess = Job.train(trainingPlan, parameters)
  Helper for the training loop.
  trainingProcess: object containing the current epoch, batch, and modelParameters
  trainingPlan: string, the name of the training plan to execute
  parameters: dict with the following values:
    planInputs: list of PlanInputSpec
    planOutputs: list of PlanOutputSpec
    data: tensor
    target: (optional) tensor
    epochs: number, how many epochs to train
    batchSize: number
    stepsPerEpoch: (optional) number, the maximum number of steps per epoch
    events: list of handlers for the events 'start', 'end', 'epochStart', 'epochEnd', 'batchStart', 'batchEnd', 'error'
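The events parameter is only described as a list of handlers. One possible reading, sketched below, assumes the handlers are passed as an object keyed by event name (an assumption; the proposal does not fix the shape). The hypothetical helper makeEventDispatcher rejects unknown event names up front and returns a trigger function that calls the matching handler, silently skipping events nobody subscribed to.

```javascript
// Event names listed in the proposal for Job.train().
const TRAINING_EVENTS = ['start', 'end', 'epochStart', 'epochEnd', 'batchStart', 'batchEnd', 'error'];

// Hypothetical helper. `handlers` is ASSUMED to be an object such as
// { epochEnd: (epoch) => ..., batchEnd: (epoch, step, status) => ... }.
function makeEventDispatcher(handlers = {}) {
  // Validate registrations once, so typos like 'batchend' fail early.
  for (const name of Object.keys(handlers)) {
    if (!TRAINING_EVENTS.includes(name)) {
      throw new Error(`Unknown training event: ${name}`);
    }
    if (typeof handlers[name] !== 'function') {
      throw new TypeError(`Handler for '${name}' must be a function`);
    }
  }
  // Plays the role of trigger_event(name, ...args) in the pseudocode.
  return function triggerEvent(name, ...args) {
    if (!TRAINING_EVENTS.includes(name)) {
      throw new Error(`Unknown training event: ${name}`);
    }
    const handler = handlers[name];
    if (handler) {
      handler(...args);
    }
  };
}
```

Keying handlers by name (rather than a positional list) keeps registration order irrelevant and makes every handler optional.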
PlanInputSpec: object that describes a plan input argument
  type: 'data' | 'target' | 'epoch' | 'batchSize' | 'step' | 'modelParameter' | 'value'
  index: (optional) number
  name: (optional) string
  value: (optional) tensor, used when type is 'value'
PlanOutputSpec: object that describes a plan output
  type: 'loss' | 'metric' | 'modelParameter'
  index: (optional) number
  name: (optional) string
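Because a plan is called with positional arguments, a malformed spec would only surface as a confusing error deep inside the plan. A hypothetical validator, sketched below, checks a single spec against the allowed type values listed above. It treats index as optional (as the pseudocode and the MNIST example do) and assumes that a 'value' input must carry a value; validateSpec is not an existing syft.js function.

```javascript
// Allowed `type` values, copied from the PlanInputSpec / PlanOutputSpec definitions.
const INPUT_TYPES = ['data', 'target', 'epoch', 'batchSize', 'step', 'modelParameter', 'value'];
const OUTPUT_TYPES = ['loss', 'metric', 'modelParameter'];

// Hypothetical validator: returns a list of problems (empty if the spec looks valid).
// `kind` is 'input' or 'output'.
function validateSpec(spec, kind) {
  const allowed = kind === 'input' ? INPUT_TYPES : OUTPUT_TYPES;
  const problems = [];
  if (!allowed.includes(spec.type)) {
    problems.push(`unknown ${kind} type '${spec.type}'`);
  }
  // index is optional, but when present it must address a list position.
  if (spec.index !== undefined && !(Number.isInteger(spec.index) && spec.index >= 0)) {
    problems.push(`index must be a non-negative integer, got ${spec.index}`);
  }
  // ASSUMPTION: a constant input of type 'value' is useless without its value.
  if (kind === 'input' && spec.type === 'value' && spec.value === undefined) {
    problems.push("a 'value' input needs a value");
  }
  if (spec.name !== undefined && typeof spec.name !== 'string') {
    problems.push('name must be a string');
  }
  return problems;
}
```

Returning a list of problems instead of throwing on the first one lets the caller report every bad spec in a plan at once.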
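Pseudocode:
Training loop:
train(trainingPlan, parameters) {
  const { planInputs, planOutputs, data, target, epochs, batchSize, stepsPerEpoch } = parameters
  // full batches per epoch, capped by stepsPerEpoch if it is given
  const fullBatches = floor(len(data) / batchSize)
  const steps = stepsPerEpoch ? min(stepsPerEpoch, fullBatches) : fullBatches
  let modelParameters = job.model.parameters
  try {
    trigger_event('start')
    for (let i = 0; i < epochs; i++) {
      trigger_event('epochStart', i)
      for (let j = 0; j < steps; j++) {
        trigger_event('batchStart', i, j)
        // take the j-th batch of data (and of target, if given)
        const x = get_batch(data, batchSize, j)
        const y = target ? get_batch(target, batchSize, j) : undefined
        const plan_args = resolve_inputs(planInputs, {
          modelParameter: modelParameters,
          data: x,
          target: y,
          epoch: i,
          batchSize: batchSize,
          step: j
        })
        const raw_outputs = job.plans[trainingPlan].execute(...plan_args)
        const outputs = resolve_outputs(planOutputs, raw_outputs)
        const status = { loss: outputs.loss, metric: outputs.metric }
        // updated parameters are fed into the next step
        modelParameters = outputs.modelParameter
        trigger_event('batchEnd', i, j, status)
      }
      trigger_event('epochEnd', i)
    }
    trigger_event('end')
  } catch (err) {
    trigger_event('error', err)
  }
}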
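The loop depends on two small pieces the pseudocode leaves abstract: how many steps make up an epoch and how the j-th batch is sliced. A minimal sketch of both follows, assuming data is a plain array of samples rather than a tensor; computeStepsPerEpoch and getBatch are hypothetical stand-ins for the pseudocode's get_batch and step computation.

```javascript
// Hypothetical helpers; samples are modelled as a plain array, not a tensor.

// Number of steps in one epoch: full batches only, optionally capped by stepsPerEpoch.
function computeStepsPerEpoch(numSamples, batchSize, stepsPerEpoch) {
  if (!Number.isInteger(batchSize) || batchSize <= 0) {
    throw new RangeError('batchSize must be a positive integer');
  }
  const fullBatches = Math.floor(numSamples / batchSize);
  return stepsPerEpoch ? Math.min(stepsPerEpoch, fullBatches) : fullBatches;
}

// Returns the samples of batch number `step` (0-based).
function getBatch(samples, batchSize, step) {
  const start = step * batchSize;
  return samples.slice(start, start + batchSize);
}
```

Dropping the trailing partial batch keeps every plan call at the same batch size, which may matter if a plan was traced with a fixed input shape; padding or a smaller last batch would be the alternatives.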
Resolving plan inputs/outputs from specs:
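resolve_inputs(specs, vars) {
  const args = []
  for (const spec of specs) {
    if (spec.type === 'value') {
      // constant input, e.g. a learning rate
      args.push(spec.value)
    } else if (spec.index !== undefined) {
      // one element of a list variable, e.g. the n-th model parameter (index 0 is valid)
      args.push(vars[spec.type][spec.index])
    } else {
      args.push(vars[spec.type])
    }
  }
  return args
}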
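resolve_outputs(specs, output) {
  const out = {}
  specs.forEach((spec, i) => {
    if (spec.index !== undefined) {
      // collect indexed outputs (e.g. model parameters) into a list
      out[spec.type] = out[spec.type] || []
      out[spec.type][spec.index] = output[i]
    } else {
      out[spec.type] = output[i]
    }
  })
  return out
}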
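Putting the loop and the two resolvers together gives the minimal runnable sketch below. It assumes data and target are plain JavaScript arrays rather than tensors, a plan is an ordinary function exposed as job.plans[name].execute, and job has the shape { model: { parameters }, plans }; none of this is taken from syft.js itself, and resolveInputs, resolveOutputs and train are hypothetical names. Errors are reported through the 'error' handler and then rethrown.

```javascript
// Minimal in-memory sketch of the proposed Job.train() loop (hypothetical, not syft.js code).

// Builds the positional argument list for a plan call from the input specs.
function resolveInputs(specs, vars) {
  return specs.map((spec) => {
    if (spec.type === 'value') return spec.value;
    if (spec.index !== undefined) return vars[spec.type][spec.index];
    return vars[spec.type];
  });
}

// Maps the plan's positional outputs back to named fields using the output specs.
function resolveOutputs(specs, outputs) {
  const out = {};
  specs.forEach((spec, i) => {
    if (spec.index !== undefined) {
      out[spec.type] = out[spec.type] || [];
      out[spec.type][spec.index] = outputs[i];
    } else {
      out[spec.type] = outputs[i];
    }
  });
  return out;
}

function train(job, trainingPlan, params) {
  const { planInputs, planOutputs, data, target, epochs, batchSize, events = {} } = params;
  // Dispatches an event to its handler, if one was registered.
  const trigger = (name, ...args) => {
    if (events[name]) events[name](...args);
  };
  // Full batches only, optionally capped by stepsPerEpoch.
  const fullBatches = Math.floor(data.length / batchSize);
  const steps = params.stepsPerEpoch ? Math.min(params.stepsPerEpoch, fullBatches) : fullBatches;
  // Plays the role of the proposed trainingProcess object.
  const state = { epoch: 0, batch: 0, modelParameters: job.model.parameters };
  try {
    trigger('start');
    for (let i = 0; i < epochs; i++) {
      state.epoch = i;
      trigger('epochStart', i);
      for (let j = 0; j < steps; j++) {
        state.batch = j;
        trigger('batchStart', i, j);
        const slice = (arr) => arr.slice(j * batchSize, (j + 1) * batchSize);
        const args = resolveInputs(planInputs, {
          modelParameter: state.modelParameters,
          data: slice(data),
          target: target ? slice(target) : undefined,
          epoch: i,
          batchSize,
          step: j,
        });
        const outputs = resolveOutputs(planOutputs, job.plans[trainingPlan].execute(...args));
        // Feed updated parameters into the next step, if the plan returned any.
        if (outputs.modelParameter !== undefined) state.modelParameters = outputs.modelParameter;
        trigger('batchEnd', i, j, { loss: outputs.loss, metric: outputs.metric });
      }
      trigger('epochEnd', i);
    }
    trigger('end');
  } catch (err) {
    trigger('error', err);
    throw err;
  }
  return state;
}
```

As a usage example, a one-parameter linear plan that returns [loss, updatedWeight] can be trained with planInputs [{type: 'data'}, {type: 'target'}, {type: 'value', value: 0.05}, {type: 'modelParameter', index: 0}] and planOutputs [{type: 'loss'}, {type: 'modelParameter', index: 0}]; the returned state then holds the final weight in modelParameters[0].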
Example input/output specs for an MNIST training plan:
planInputs:
[{type: 'data'}, {type: 'target'}, {type: 'batchSize'}, {type: 'value', value: <lr>},
 {type: 'modelParameter', index: 0}, {type: 'modelParameter', index: 1},
 {type: 'modelParameter', index: 2}, {type: 'modelParameter', index: 3}]
planOutputs:
[{type: 'loss'}, {type: 'metric'},
 {type: 'modelParameter', index: 0}, {type: 'modelParameter', index: 1},
 {type: 'modelParameter', index: 2}, {type: 'modelParameter', index: 3}]
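In the MNIST example each model parameter enters the plan at one position and must come back at the matching output position, because the loop feeds the returned modelParameter list into the next step. A hypothetical pre-flight check, sketched below, compares the parameter indices on both sides; the learning rate 0.01 stands in for the issue's <lr> placeholder and is arbitrary.

```javascript
// Hypothetical pre-flight check: every model parameter index fed into the plan
// must also be returned by it, otherwise the next step would receive a
// stale or missing parameter.
function parameterIndices(specs) {
  return specs
    .filter((spec) => spec.type === 'modelParameter')
    .map((spec) => spec.index)
    .sort((a, b) => a - b);
}

function checkParameterRoundTrip(planInputs, planOutputs) {
  const inIdx = parameterIndices(planInputs);
  const outIdx = parameterIndices(planOutputs);
  const missing = inIdx.filter((i) => !outIdx.includes(i)); // consumed but never returned
  const extra = outIdx.filter((i) => !inIdx.includes(i)); // returned but never consumed
  return { ok: missing.length === 0 && extra.length === 0, missing, extra };
}

// The MNIST specs from the example (learning rate is an arbitrary placeholder).
const mnistInputs = [
  { type: 'data' }, { type: 'target' }, { type: 'batchSize' }, { type: 'value', value: 0.01 },
  { type: 'modelParameter', index: 0 }, { type: 'modelParameter', index: 1 },
  { type: 'modelParameter', index: 2 }, { type: 'modelParameter', index: 3 },
];
const mnistOutputs = [
  { type: 'loss' }, { type: 'metric' },
  { type: 'modelParameter', index: 0 }, { type: 'modelParameter', index: 1 },
  { type: 'modelParameter', index: 2 }, { type: 'modelParameter', index: 3 },
];
console.log(checkParameterRoundTrip(mnistInputs, mnistOutputs).ok); // true
```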
What alternatives have you considered?
The API was discussed within the FL team.
Additional Context
n/a