diff --git a/lib/Service/ProvidersAI/TaskProcessingService.php b/lib/Service/ProvidersAI/TaskProcessingService.php index d0fbcdd4..24e5f5a3 100644 --- a/lib/Service/ProvidersAI/TaskProcessingService.php +++ b/lib/Service/ProvidersAI/TaskProcessingService.php @@ -103,6 +103,21 @@ private function everyElementHasKeys(array|null $array, array $keys): bool { return true; } + private function everyArrayElementHasKeys(array|null $array, array $keys): bool { + if (!is_array($array)) { + return false; + } + + foreach ($array as $element) { + foreach ($keys as $key) { + if (!array_key_exists($key, $element)) { + return false; + } + } + } + return true; + } + private function validateTaskProcessingProvider(array $provider): void { if (!isset($provider['id']) || !is_string($provider['id'])) { throw new Exception('"id" key must be a string'); @@ -116,10 +131,10 @@ private function validateTaskProcessingProvider(array $provider): void { if (!isset($provider['expected_runtime']) || !is_int($provider['expected_runtime'])) { throw new Exception('"expected_runtime" key must be an integer'); } - if (!$this->everyElementHasKeys($provider['optional_input_shape'], ['name', 'description', 'shape_type'])) { + if (!$this->everyArrayElementHasKeys($provider['optional_input_shape'], ['name', 'description', 'shape_type'])) { throw new Exception('"optional_input_shape" should be an array and must have "name", "description" and "shape_type" keys'); } - if (!$this->everyElementHasKeys($provider['optional_output_shape'], ['name', 'description', 'shape_type'])) { + if (!$this->everyArrayElementHasKeys($provider['optional_output_shape'], ['name', 'description', 'shape_type'])) { throw new Exception('"optional_output_shape" should be an array and must have "name", "description" and "shape_type" keys'); } if (!$this->everyElementHasKeys($provider['input_shape_enum_values'], ['name', 'value'])) { @@ -270,23 +285,25 @@ public function getExpectedRuntime(): int { } public function getOptionalInputShape(): array { - return array_map(function ($shape) { - return new ShapeDescriptor( + return array_reduce($this->provider['optional_input_shape'], function (array $input, array $shape) { + $input[$shape['name']] = new ShapeDescriptor( $shape['name'], $shape['description'], EShapeType::from($shape['shape_type']), ); - }, $this->provider['optional_input_shape']); + return $input; + }, []); } public function getOptionalOutputShape(): array { - return array_map(static function (array $shape) { - return new ShapeDescriptor( + return array_reduce($this->provider['optional_output_shape'], function (array $input, array $shape) { + $input[$shape['name']] = new ShapeDescriptor( $shape['name'], $shape['description'], EShapeType::from($shape['shape_type']), ); - }, $this->provider['optional_output_shape']); + return $input; + }, []); } public function getInputShapeEnumValues(): array {