Skip to content

Commit 9049405

Browse files
committed
Initial commit
Use pickle for model2wasm Fix TinyGo entrypoint and m2cgen wrapping workflow Add FL task mode and metadata fields Signed-off-by: JeffMboya <jangina.mboya@gmail.com> Include FL mode and spec in start payload Signed-off-by: JeffMboya <jangina.mboya@gmail.com> Add FL update envelope type for results Signed-off-by: JeffMboya <jangina.mboya@gmail.com> Deduplicate FL helpers; update host and wazero Signed-off-by: JeffMboya <jangina.mboya@gmail.com> Handle train results as FL envelope updates Signed-off-by: JeffMboya <jangina.mboya@gmail.com>
1 parent 22541f0 commit 9049405

File tree

13 files changed

+919
-0
lines changed

13 files changed

+919
-0
lines changed

examples/ml-wasm/main.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
package main
2+
3+
func predict(x0 int32, x1 int32) int32 {
4+
in := []float64{
5+
float64(x0) / 100.0,
6+
float64(x1) / 100.0,
7+
}
8+
9+
y := score(in)
10+
11+
return int32(y * 100.0)
12+
}
13+
14+
func main() {}

examples/ml-wasm/model_gen.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
package main
2+
3+
func score(input []float64) float64 {
4+
return 0.0000000000000008881784197001252 + input[0] * 1.9999999999999996 + input[1] * 2.999999999999999
5+
}
6+

examples/ml-wasm/mymodel.pkl

451 Bytes
Binary file not shown.

examples/ml-wasm/mymodel.wasm

93.1 KB
Binary file not shown.

examples/ml-wasm/train_model.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
from sklearn.linear_model import LinearRegression
2+
import numpy as np
3+
import os
4+
import pickle
5+
6+
script_dir = os.path.dirname(os.path.abspath(__file__))
7+
8+
model_path = os.path.join(script_dir, "mymodel.pkl")
9+
10+
# Fake training data: y = 2*x1 + 3*x2
11+
X = np.array([
12+
[0, 0],
13+
[1, 0],
14+
[0, 1],
15+
[1, 1],
16+
[2, 1],
17+
[1, 2],
18+
], dtype=float)
19+
20+
y = 2 * X[:, 0] + 3 * X[:, 1]
21+
22+
model = LinearRegression()
23+
model.fit(X, y)
24+
25+
with open(model_path, "wb") as f:
26+
pickle.dump(model, f)
27+
28+
print(f"Saved model to {model_path}")

manager/service.go

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,15 @@ package manager
22

33
import (
44
"context"
5+
"encoding/json"
56
"errors"
67
"fmt"
78
"log/slog"
89
"time"
910

1011
"github.com/0x6flab/namegenerator"
1112
pkgerrors "github.com/absmach/propeller/pkg/errors"
13+
flpkg "github.com/absmach/propeller/pkg/fl"
1214
"github.com/absmach/propeller/pkg/mqtt"
1315
"github.com/absmach/propeller/pkg/proplet"
1416
"github.com/absmach/propeller/pkg/scheduler"
@@ -435,6 +437,70 @@ func (svc *service) updateResultsHandler(ctx context.Context, msg map[string]any
435437
t.UpdatedAt = time.Now()
436438
t.FinishTime = time.Now()
437439

440+
if errMsg, ok := msg["error"].(string); ok && errMsg != "" {
441+
t.Error = errMsg
442+
}
443+
444+
if err := svc.tasksDB.Update(ctx, t.ID, t); err != nil {
445+
return err
446+
}
447+
448+
const scanLimit = 10_000
449+
data, _, lErr := svc.tasksDB.List(ctx, 0, scanLimit)
450+
if lErr != nil {
451+
return lErr
452+
}
453+
454+
var (
455+
pendingOrRunning bool
456+
numUpdates uint64
457+
totalSamples uint64
458+
)
459+
460+
for i := range data {
461+
tt, ok := data[i].(task.Task)
462+
if !ok {
463+
return pkgerrors.ErrInvalidData
464+
}
465+
if tt.Mode != task.ModeTrain || tt.FL == nil {
466+
continue
467+
}
468+
if tt.FL.JobID != envlp.JobID || tt.FL.RoundID != envlp.RoundID {
469+
continue
470+
}
471+
472+
switch tt.State {
473+
case task.Pending, task.Scheduled, task.Running:
474+
pendingOrRunning = true
475+
}
476+
477+
if tt.State == task.Completed && tt.Error == "" {
478+
if u, ok := tt.Results.(flpkg.UpdateEnvelope); ok {
479+
numUpdates++
480+
totalSamples += u.NumSamples
481+
} else {
482+
// If storage rehydrates as map[string]any, ignore for aggregation stub.
483+
}
484+
}
485+
}
486+
487+
if !pendingOrRunning {
488+
svc.logger.InfoContext(ctx, "FL round complete (aggregation stub)",
489+
slog.String("job_id", envlp.JobID),
490+
slog.Uint64("round_id", envlp.RoundID),
491+
slog.Uint64("num_updates", numUpdates),
492+
slog.Uint64("total_samples", totalSamples),
493+
)
494+
}
495+
496+
return nil
497+
}
498+
499+
t.Results = msg["results"]
500+
t.State = task.Completed
501+
t.UpdatedAt = time.Now()
502+
t.FinishTime = time.Now()
503+
438504
if errMsg, ok := msg["error"].(string); ok && errMsg != "" {
439505
t.Error = errMsg
440506
}
@@ -534,6 +600,94 @@ func (svc *service) handlePropletMetrics(ctx context.Context, msg map[string]any
534600
return nil
535601
}
536602

603+
func (svc *service) handleTaskMetrics(ctx context.Context, msg map[string]any) error {
604+
taskID, ok := msg["task_id"].(string)
605+
if !ok {
606+
return errors.New("invalid task_id")
607+
}
608+
if taskID == "" {
609+
return errors.New("task id is empty")
610+
}
611+
612+
propletID, ok := msg["proplet_id"].(string)
613+
if !ok {
614+
return errors.New("invalid proplet_id")
615+
}
616+
617+
taskMetrics := TaskMetrics{
618+
TaskID: taskID,
619+
PropletID: propletID,
620+
}
621+
622+
if ts, ok := msg["timestamp"].(string); ok {
623+
if t, err := time.Parse(time.RFC3339Nano, ts); err == nil {
624+
taskMetrics.Timestamp = t
625+
}
626+
}
627+
if taskMetrics.Timestamp.IsZero() {
628+
taskMetrics.Timestamp = time.Now()
629+
}
630+
631+
if metricsData, ok := msg["metrics"].(map[string]any); ok {
632+
taskMetrics.Metrics = svc.parseProcessMetrics(metricsData)
633+
}
634+
635+
if aggData, ok := msg["aggregated"].(map[string]any); ok {
636+
taskMetrics.Aggregated = svc.parseAggregatedMetrics(aggData)
637+
}
638+
639+
key := fmt.Sprintf("%s:%d", taskID, taskMetrics.Timestamp.UnixNano())
640+
if err := svc.metricsDB.Create(ctx, key, taskMetrics); err != nil {
641+
svc.logger.WarnContext(ctx, "failed to store task metrics", "error", err, "task_id", taskID)
642+
643+
return err
644+
}
645+
646+
return nil
647+
}
648+
649+
func (svc *service) handlePropletMetrics(ctx context.Context, msg map[string]any) error {
650+
propletID, ok := msg["proplet_id"].(string)
651+
if !ok {
652+
return errors.New("invalid proplet_id")
653+
}
654+
if propletID == "" {
655+
return errors.New("proplet id is empty")
656+
}
657+
namespace, _ := msg["namespace"].(string)
658+
659+
propletMetrics := PropletMetrics{
660+
PropletID: propletID,
661+
Namespace: namespace,
662+
}
663+
664+
if ts, ok := msg["timestamp"].(string); ok {
665+
if t, err := time.Parse(time.RFC3339Nano, ts); err == nil {
666+
propletMetrics.Timestamp = t
667+
}
668+
}
669+
if propletMetrics.Timestamp.IsZero() {
670+
propletMetrics.Timestamp = time.Now()
671+
}
672+
673+
if cpuData, ok := msg["cpu_metrics"].(map[string]any); ok {
674+
propletMetrics.CPU = svc.parseCPUMetrics(cpuData)
675+
}
676+
677+
if memData, ok := msg["memory_metrics"].(map[string]any); ok {
678+
propletMetrics.Memory = svc.parseMemoryMetrics(memData)
679+
}
680+
681+
key := fmt.Sprintf("%s:%d", propletID, propletMetrics.Timestamp.UnixNano())
682+
if err := svc.metricsDB.Create(ctx, key, propletMetrics); err != nil {
683+
svc.logger.WarnContext(ctx, "failed to store proplet metrics", "error", err, "proplet_id", propletID)
684+
685+
return err
686+
}
687+
688+
return nil
689+
}
690+
537691
func (svc *service) parseProcessMetrics(data map[string]any) proplet.ProcessMetrics {
538692
metrics := proplet.ProcessMetrics{}
539693

pkg/fl/types.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
package fl
2+
3+
type UpdateEnvelope struct {
4+
TaskID string `json:"task_id,omitempty"`
5+
JobID string `json:"job_id"`
6+
RoundID uint64 `json:"round_id"`
7+
GlobalVersion string `json:"global_version"`
8+
PropletID string `json:"proplet_id"`
9+
NumSamples uint64 `json:"num_samples"`
10+
UpdateB64 string `json:"update_b64"`
11+
Metrics map[string]any `json:"metrics,omitempty"`
12+
Format string `json:"format,omitempty"`
13+
}

proplet/runtimes/fl_helpers.go

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
package runtimes
2+
3+
import (
4+
"encoding/base64"
5+
"fmt"
6+
"strconv"
7+
8+
flpkg "github.com/absmach/propeller/pkg/fl"
9+
)
10+
11+
func buildFLPayloadFromString(taskID, mode, propletID string, env map[string]string, rawOut string) map[string]any {
12+
// Backward-compatible inference behavior.
13+
if mode != "train" {
14+
return map[string]any{
15+
"task_id": taskID,
16+
"results": rawOut,
17+
}
18+
}
19+
20+
envlp := flpkg.UpdateEnvelope{
21+
TaskID: taskID,
22+
JobID: env["FL_JOB_ID"],
23+
RoundID: parseUint64(env, "FL_ROUND_ID"),
24+
GlobalVersion: env["FL_GLOBAL_VERSION"],
25+
PropletID: propletID,
26+
NumSamples: parseUint64(env, "FL_NUM_SAMPLES"),
27+
UpdateB64: base64.StdEncoding.EncodeToString([]byte(rawOut)),
28+
Metrics: nil,
29+
Format: env["FL_FORMAT"],
30+
}
31+
32+
return map[string]any{
33+
"task_id": taskID,
34+
"results": envlp,
35+
}
36+
}
37+
38+
func buildFLPayloadFromUint64Slice(taskID, mode, propletID string, env map[string]string, results []uint64) map[string]any {
39+
if mode != "train" {
40+
return map[string]any{
41+
"task_id": taskID,
42+
"results": results,
43+
}
44+
}
45+
46+
raw := fmt.Sprint(results)
47+
48+
envlp := flpkg.UpdateEnvelope{
49+
TaskID: taskID,
50+
JobID: env["FL_JOB_ID"],
51+
RoundID: parseUint64(env, "FL_ROUND_ID"),
52+
GlobalVersion: env["FL_GLOBAL_VERSION"],
53+
PropletID: propletID,
54+
NumSamples: parseUint64(env, "FL_NUM_SAMPLES"),
55+
UpdateB64: base64.StdEncoding.EncodeToString([]byte(raw)),
56+
Metrics: nil,
57+
Format: env["FL_FORMAT"],
58+
}
59+
60+
return map[string]any{
61+
"task_id": taskID,
62+
"results": envlp,
63+
}
64+
}
65+
66+
func parseUint64(env map[string]string, key string) uint64 {
67+
if env == nil {
68+
return 0
69+
}
70+
s, ok := env[key]
71+
if !ok || s == "" {
72+
return 0
73+
}
74+
v, err := strconv.ParseUint(s, 10, 64)
75+
if err != nil {
76+
return 0
77+
}
78+
return v
79+
}

0 commit comments

Comments
 (0)