@@ -39,73 +39,67 @@ func (s *composeService) ensureModels(ctx context.Context, project *types.Projec
3939 return nil
4040 }
4141
42- dockerModel , err := manager .GetPlugin ("model" , s .dockerCli , & cobra.Command {})
43- if err != nil {
44- if errdefs .IsNotFound (err ) {
45- return fmt .Errorf ("'models' support requires Docker Model plugin" )
46- }
47- return err
48- }
49-
50- cmd := exec .CommandContext (ctx , dockerModel .Path , "ls" , "--json" )
51- err = s .prepareShellOut (ctx , project , cmd )
42+ api , err := s .newModelAPI (project )
5243 if err != nil {
5344 return err
5445 }
46+ availableModels , err := api .ListModels (ctx )
5547
56- output , err := cmd .CombinedOutput ()
57- if err != nil {
58- return fmt .Errorf ("error checking available models: %w" , err )
59- }
60-
61- type AvailableModel struct {
62- Id string `json:"id"`
63- Tags []string `json:"tags"`
64- Created int `json:"created"`
65- }
66-
67- models := []AvailableModel {}
68- err = json .Unmarshal (output , & models )
69- if err != nil {
70- return fmt .Errorf ("error unmarshalling available models: %w" , err )
71- }
72- var availableModels []string
73- for _ , model := range models {
74- availableModels = append (availableModels , model .Tags ... )
75- }
76-
77- eg , gctx := errgroup .WithContext (ctx )
48+ eg , ctx := errgroup .WithContext (ctx )
7849 eg .Go (func () error {
79- return s . setModelVariables ( gctx , dockerModel , project )
50+ return api . SetModelVariables ( ctx , project )
8051 })
8152
53+ w := progress .ContextWriter (ctx )
8254 for name , config := range project .Models {
8355 if config .Name == "" {
8456 config .Name = name
8557 }
8658 eg .Go (func () error {
87- w := progress .ContextWriter (gctx )
8859 if ! slices .Contains (availableModels , config .Model ) {
89- err = s . pullModel ( gctx , dockerModel , project , config , quietPull , w )
60+ err = api . PullModel ( ctx , config , quietPull , w )
9061 if err != nil {
9162 return err
9263 }
9364 }
94- return s . configureModel ( gctx , dockerModel , project , config , w )
65+ return api . ConfigureModel ( ctx , config , w )
9566 })
9667 }
9768 return eg .Wait ()
9869}
9970
100- func (s * composeService ) pullModel (ctx context.Context , dockerModel * manager.Plugin , project * types.Project , model types.ModelConfig , quietPull bool , w progress.Writer ) error {
71+ type modelAPI struct {
72+ path string
73+ env []string
74+ prepare func (ctx context.Context , cmd * exec.Cmd ) error
75+ }
76+
77+ func (s * composeService ) newModelAPI (project * types.Project ) (* modelAPI , error ) {
78+ dockerModel , err := manager .GetPlugin ("model" , s .dockerCli , & cobra.Command {})
79+ if err != nil {
80+ if errdefs .IsNotFound (err ) {
81+ return nil , fmt .Errorf ("'models' support requires Docker Model plugin" )
82+ }
83+ return nil , err
84+ }
85+ return & modelAPI {
86+ path : dockerModel .Path ,
87+ prepare : func (ctx context.Context , cmd * exec.Cmd ) error {
88+ return s .prepareShellOut (ctx , project .Environment , cmd )
89+ },
90+ env : project .Environment .Values (),
91+ }, nil
92+ }
93+
94+ func (m * modelAPI ) PullModel (ctx context.Context , model types.ModelConfig , quietPull bool , w progress.Writer ) error {
10195 w .Event (progress.Event {
10296 ID : model .Name ,
10397 Status : progress .Working ,
10498 Text : "Pulling" ,
10599 })
106100
107- cmd := exec .CommandContext (ctx , dockerModel . Path , "pull" , model .Model )
108- err := s . prepareShellOut (ctx , project , cmd )
101+ cmd := exec .CommandContext (ctx , m . path , "pull" , model .Model )
102+ err := m . prepare (ctx , cmd )
109103 if err != nil {
110104 return err
111105 }
@@ -148,7 +142,7 @@ func (s *composeService) pullModel(ctx context.Context, dockerModel *manager.Plu
148142 return err
149143}
150144
151- func (s * composeService ) configureModel (ctx context.Context , dockerModel * manager. Plugin , project * types. Project , config types.ModelConfig , w progress.Writer ) error {
145+ func (m * modelAPI ) ConfigureModel (ctx context.Context , config types.ModelConfig , w progress.Writer ) error {
152146 w .Event (progress.Event {
153147 ID : config .Name ,
154148 Status : progress .Working ,
@@ -164,17 +158,17 @@ func (s *composeService) configureModel(ctx context.Context, dockerModel *manage
164158 args = append (args , "--" )
165159 args = append (args , config .RuntimeFlags ... )
166160 }
167- cmd := exec .CommandContext (ctx , dockerModel . Path , args ... )
168- err := s . prepareShellOut (ctx , project , cmd )
161+ cmd := exec .CommandContext (ctx , m . path , args ... )
162+ err := m . prepare (ctx , cmd )
169163 if err != nil {
170164 return err
171165 }
172166 return cmd .Run ()
173167}
174168
175- func (s * composeService ) setModelVariables (ctx context.Context , dockerModel * manager. Plugin , project * types.Project ) error {
176- cmd := exec .CommandContext (ctx , dockerModel . Path , "status" , "--json" )
177- err := s . prepareShellOut (ctx , project , cmd )
169+ func (m * modelAPI ) SetModelVariables (ctx context.Context , project * types.Project ) error {
170+ cmd := exec .CommandContext (ctx , m . path , "status" , "--json" )
171+ err := m . prepare (ctx , cmd )
178172 if err != nil {
179173 return err
180174 }
@@ -228,3 +222,33 @@ type Model struct {
228222 Size string `json:"size"`
229223 } `json:"config"`
230224}
225+
226+ func (m * modelAPI ) ListModels (ctx context.Context ) ([]string , error ) {
227+ cmd := exec .CommandContext (ctx , m .path , "ls" , "--json" )
228+ err := m .prepare (ctx , cmd )
229+ if err != nil {
230+ return nil , err
231+ }
232+
233+ output , err := cmd .CombinedOutput ()
234+ if err != nil {
235+ return nil , fmt .Errorf ("error checking available models: %w" , err )
236+ }
237+
238+ type AvailableModel struct {
239+ Id string `json:"id"`
240+ Tags []string `json:"tags"`
241+ Created int `json:"created"`
242+ }
243+
244+ models := []AvailableModel {}
245+ err = json .Unmarshal (output , & models )
246+ if err != nil {
247+ return nil , fmt .Errorf ("error unmarshalling available models: %w" , err )
248+ }
249+ var availableModels []string
250+ for _ , model := range models {
251+ availableModels = append (availableModels , model .Tags ... )
252+ }
253+ return availableModels , nil
254+ }
0 commit comments