@@ -84,7 +84,7 @@ public static function fromPretrained(
8484 string $ revision = 'main ' ,
8585 ?string $ modelFilename = null ,
8686 ModelArchitecture $ modelArchitecture = ModelArchitecture::EncoderOnly,
87- ?callable $ onProgress = null
87+ ?callable $ onProgress = null
8888 ): self
8989 {
9090 if (is_array ($ config )) {
@@ -115,7 +115,12 @@ public static function fromPretrained(
115115
116116 $ generatorConfig = new GenerationConfig ($ generatorConfigArr );
117117
118- return new static ($ config , $ session , $ modelArchitecture , $ generatorConfig );
118+ return new static (
119+ config: $ config ,
120+ session: $ session ,
121+ modelArchitecture: $ modelArchitecture ,
122+ generationConfig: $ generatorConfig
123+ );
119124 }
120125
121126 case ModelArchitecture::Seq2SeqLM:
@@ -148,8 +153,13 @@ public static function fromPretrained(
148153
149154 $ generatorConfig = new GenerationConfig ($ generatorConfigArr );
150155
151-
152- return new static ($ config , $ encoderSession , $ decoderSession , $ modelArchitecture , $ generatorConfig );
156+ return new static (
157+ config: $ config ,
158+ session: $ encoderSession ,
159+ modelArchitecture: $ modelArchitecture ,
160+ generationConfig: $ generatorConfig ,
161+ decoderMergedSession: $ decoderSession
162+ );
153163 }
154164
155165 case ModelArchitecture::MaskGeneration:
@@ -170,7 +180,12 @@ public static function fromPretrained(
170180 onProgress: $ onProgress
171181 );
172182
173- return new static ($ config , $ visionEncoder , $ promptMaskEncoder , $ modelArchitecture );
183+ return new static (
184+ config: $ config ,
185+ session: $ visionEncoder ,
186+ promptMaskEncoderSession: $ promptMaskEncoder ,
187+ modelArchitecture: $ modelArchitecture
188+ );
174189 }
175190
176191 case ModelArchitecture::EncoderDecoder:
@@ -191,7 +206,12 @@ public static function fromPretrained(
191206 onProgress: $ onProgress
192207 );
193208
194- return new static ($ config , $ encoderSession , $ decoderSession , $ modelArchitecture );
209+ return new static (
210+ config: $ config ,
211+ session: $ encoderSession ,
212+ decoderMergedSession: $ decoderSession ,
213+ modelArchitecture: $ modelArchitecture
214+ );
195215 }
196216
197217 default :
@@ -210,7 +230,11 @@ public static function fromPretrained(
210230 );
211231
212232
213- return new static ($ config , $ session , $ modelArchitecture );
233+ return new static (
234+ config: $ config ,
235+ session: $ session ,
236+ modelArchitecture: $ modelArchitecture
237+ );
214238 }
215239 }
216240 }
@@ -232,14 +256,14 @@ public static function fromPretrained(
232256 */
233257
234258 public static function constructSession (
235- string $ modelNameOrPath ,
236- string $ fileName ,
237- ?string $ cacheDir = null ,
238- string $ revision = 'main ' ,
239- string $ subFolder = 'onnx ' ,
240- bool $ fatal = true ,
259+ string $ modelNameOrPath ,
260+ string $ fileName ,
261+ ?string $ cacheDir = null ,
262+ string $ revision = 'main ' ,
263+ string $ subFolder = 'onnx ' ,
264+ bool $ fatal = true ,
241265 ?callable $ onProgress = null ,
242- ...$ sessionOptions
266+ ...$ sessionOptions
243267 ): ?InferenceSession
244268 {
245269 $ modelFileName = "$ fileName.onnx " ;
0 commit comments