From 32524ba2c4292ec06dcf40176050b9b43b168378 Mon Sep 17 00:00:00 2001 From: Alon Burg Date: Sun, 28 Jan 2024 09:20:22 +0200 Subject: [PATCH 1/4] astria api --- edenai_apis/apis/astria/astria_api.py | 82 ++++++++++++++++++++++----- 1 file changed, 67 insertions(+), 15 deletions(-) diff --git a/edenai_apis/apis/astria/astria_api.py b/edenai_apis/apis/astria/astria_api.py index 88169eec0..d995bc186 100644 --- a/edenai_apis/apis/astria/astria_api.py +++ b/edenai_apis/apis/astria/astria_api.py @@ -12,6 +12,11 @@ AsyncBaseResponseType, ) +import requests + +def load_image(file_path): + with open(file_path, "rb") as f: + return f.read() class AstriaApi(ProviderInterface, ImageInterface): provider_name = "astria" @@ -26,30 +31,77 @@ def __init__(self, api_keys: Dict = {}) -> None: self.headers = {"authorization": f"Bearer {self.api_key}"} def image__generation_fine_tuning__create_project_async__launch_job( - self, - name: str, - description: str, - files: List[str], - files_url: List[str] = [], - base_project_id: Optional[int] = None, + self, + title: str, + class_name: str, + files: List[str] = [], + files_url: List[str] = [], + base_tune_id: Optional[int] = None, ) -> AsyncLaunchJobResponseType: - raise NotImplementedError + data = { + "tune[title]": title, + "tune[name]": class_name, + "tune[base_tune_id]": 690204, + # "tune[callback]": 'https://optional-callback-url.com/to-your-service-when-ready?prompt_id=1' + } + for image in files: + image_data = load_image(image) # Assuming image is a file path + files.append(("tune[images][]", image_data)) + for image_url in files_url: + files.append(("tune[image_urls][]", image_url)) + + response = requests.post(f"{self.url}tunes", data=data, files=files, headers=self.headers) + response.raise_for_status() + return AsyncLaunchJobResponseType(provider_job_id=response.json()["id"]) def image__generation_fine_tuning__create_project_async__get_job_result( self, provider_job_id: str ) -> AsyncBaseResponseType[GenerationFineTuningCreateProjectAsyncDataClass]: - raise NotImplementedError + response = requests.get(f"{self.url}tunes/{provider_job_id}", headers=self.headers) + response.raise_for_status() + data = response.json() + return AsyncBaseResponseType( + status="succeeded" if data['trained_at'] else "pending", + provider_job_id=provider_job_id, + original_response=data, + standardized_response=GenerationFineTuningCreateProjectAsyncDataClass(**data), + ) def image__generation_fine_tuning__generate_image_async__launch_job( - self, - project_id: str, - prompt: str, - negative_prompt: Optional[str] = "", - num_images: Optional[int] = 1, + self, + project_id: str, + prompt: str, + negative_prompt: Optional[str] = "", + num_images: Optional[int] = 1, + input_image: Optional[str] = None, ) -> AsyncLaunchJobResponseType: - raise NotImplementedError + data = { + 'prompt[text]': prompt, + 'prompt[negative_prompt]': negative_prompt, + 'prompt[num_images]': num_images, + 'prompt[face_swap]': True, + 'prompt[inpaint_faces]': False, + 'prompt[super_resolution]': True, + 'prompt[face_correct]': False, + # 'prompt[callback]': 'https://optional-callback-url.com/to-your-service-when-ready?prompt_id=1' + } + files = [] + if input_image: + files.append((f"tune[prompts_attributes][{i}][input_image]", load_image(input_image))) + + response = requests.post(f"{self.url}/tunes/{project_id}", headers=self.headers, data=data) + response.raise_for_status() + return AsyncLaunchJobResponseType(provider_job_id=response.json()["id"]) def image__generation_fine_tuning__generate_image_async__get_job_result( self, provider_job_id: str ) -> AsyncBaseResponseType[GenerationFineTuningGenerateImageAsyncDataClass]: - raise NotImplementedError + response = requests.get(f"{self.url}tunes/{provider_job_id}", headers=self.headers) + response.raise_for_status() + data = response.json() + return AsyncBaseResponseType( + status="succeeded" if data['trained_at'] else "pending", + provider_job_id=provider_job_id, + original_response=data, + standardized_response=GenerationFineTuningGenerateImageAsyncDataClass(**data), + ) From 44a597f2fe3e31760b7d887c6debbda74de4aa38 Mon Sep 17 00:00:00 2001 From: Alon Burg Date: Fri, 2 Feb 2024 11:17:22 +0200 Subject: [PATCH 2/4] astria api - base_tune_id --- edenai_apis/apis/astria/astria_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/edenai_apis/apis/astria/astria_api.py b/edenai_apis/apis/astria/astria_api.py index d995bc186..19caf5601 100644 --- a/edenai_apis/apis/astria/astria_api.py +++ b/edenai_apis/apis/astria/astria_api.py @@ -41,7 +41,7 @@ def image__generation_fine_tuning__create_project_async__launch_job( data = { "tune[title]": title, "tune[name]": class_name, - "tune[base_tune_id]": 690204, + "tune[base_tune_id]": base_tune_id, # "tune[callback]": 'https://optional-callback-url.com/to-your-service-when-ready?prompt_id=1' } for image in files: From bc9e854a7da7cfec760dda2a3bf856f6c84b9319 Mon Sep 17 00:00:00 2001 From: Alon Burg Date: Fri, 2 Feb 2024 11:19:01 +0200 Subject: [PATCH 3/4] astria api - GenerationFineTuningCreateProjectAsyncDataClass --- edenai_apis/apis/astria/astria_api.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/edenai_apis/apis/astria/astria_api.py b/edenai_apis/apis/astria/astria_api.py index 19caf5601..f7685a82d 100644 --- a/edenai_apis/apis/astria/astria_api.py +++ b/edenai_apis/apis/astria/astria_api.py @@ -64,7 +64,11 @@ def image__generation_fine_tuning__create_project_async__get_job_result( status="succeeded" if data['trained_at'] else "pending", provider_job_id=provider_job_id, original_response=data, - standardized_response=GenerationFineTuningCreateProjectAsyncDataClass(**data), + standardized_response=GenerationFineTuningCreateProjectAsyncDataClass( + project_id=data["id"], + name=data["name"], + description=data["title"], + ), ) def image__generation_fine_tuning__generate_image_async__launch_job( From 5aa343fcc3360597f93e72818303fa9665580829 Mon Sep 17 00:00:00 2001 From: Alon Burg Date: Fri, 2 Feb 2024 14:03:49 +0200 Subject: [PATCH 4/4] astria api - PR fixes --- edenai_apis/apis/astria/astria_api.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/edenai_apis/apis/astria/astria_api.py b/edenai_apis/apis/astria/astria_api.py index f7685a82d..54218cf84 100644 --- a/edenai_apis/apis/astria/astria_api.py +++ b/edenai_apis/apis/astria/astria_api.py @@ -52,7 +52,7 @@ def image__generation_fine_tuning__create_project_async__launch_job( response = requests.post(f"{self.url}tunes", data=data, files=files, headers=self.headers) response.raise_for_status() - return AsyncLaunchJobResponseType(provider_job_id=response.json()["id"]) + return AsyncLaunchJobResponseType(provider_job_id=str(response.json()["id"])) def image__generation_fine_tuning__create_project_async__get_job_result( self, provider_job_id: str @@ -71,6 +71,7 @@ def image__generation_fine_tuning__create_project_async__get_job_result( ), ) + # https://docs.astria.ai/docs/api/prompt/create def image__generation_fine_tuning__generate_image_async__launch_job( self, project_id: str, @@ -78,25 +79,31 @@ def image__generation_fine_tuning__generate_image_async__launch_job( negative_prompt: Optional[str] = "", num_images: Optional[int] = 1, input_image: Optional[str] = None, + # Only if name=man/woman + face_swap: Optional[bool] = True, + inpaint_faces: Optional[bool] = True, + super_resolution: Optional[bool] = True, + face_correct: Optional[bool] = False, ) -> AsyncLaunchJobResponseType: data = { 'prompt[text]': prompt, 'prompt[negative_prompt]': negative_prompt, 'prompt[num_images]': num_images, - 'prompt[face_swap]': True, - 'prompt[inpaint_faces]': False, - 'prompt[super_resolution]': True, - 'prompt[face_correct]': False, + 'prompt[face_swap]': face_swap, + 'prompt[inpaint_faces]': inpaint_faces, + 'prompt[super_resolution]': super_resolution, + 'prompt[face_correct]': face_correct, # 'prompt[callback]': 'https://optional-callback-url.com/to-your-service-when-ready?prompt_id=1' } files = [] if input_image: - files.append((f"tune[prompts_attributes][{i}][input_image]", load_image(input_image))) + files.append((f"prompt[input_image]", load_image(input_image))) - response = requests.post(f"{self.url}/tunes/{project_id}", headers=self.headers, data=data) + response = requests.post(f"{self.url}/tunes/{project_id}/prompts", headers=self.headers, data=data, files=files) response.raise_for_status() return AsyncLaunchJobResponseType(provider_job_id=response.json()["id"]) + # https://docs.astria.ai/docs/api/prompt/retrieve def image__generation_fine_tuning__generate_image_async__get_job_result( self, provider_job_id: str ) -> AsyncBaseResponseType[GenerationFineTuningGenerateImageAsyncDataClass]: @@ -107,5 +114,5 @@ def image__generation_fine_tuning__generate_image_async__get_job_result( status="succeeded" if data['trained_at'] else "pending", provider_job_id=provider_job_id, original_response=data, - standardized_response=GenerationFineTuningGenerateImageAsyncDataClass(**data), + standardized_response=GenerationFineTuningGenerateImageAsyncDataClass(images_url=data["images"]), )