Skip to content

Commit bd04ff7

Browse files
committed
Add new options to generate method and CLI
1 parent ceb42ad commit bd04ff7

File tree

4 files changed

+263
-8
lines changed

4 files changed

+263
-8
lines changed

firefly/cli.py

Lines changed: 43 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,15 @@ def generate(
101101
envvar=["FIREFLY_CLIENT_SECRET"],
102102
),
103103
prompt: str = typer.Option(..., help="Text prompt for image generation"),
104+
num_variations: int = typer.Option(None, help="Number of images to generate (numVariations)"),
105+
style: str = typer.Option(None, help="Style object as JSON string (presets, imageReference, strength, etc.)"),
106+
structure: str = typer.Option(None, help="Structure object as JSON string (strength, imageReference, etc.)"),
107+
prompt_biasing_locale_code: str = typer.Option(None, help="Locale code for prompt biasing (promptBiasingLocaleCode)"),
108+
negative_prompt: str = typer.Option(None, help="Negative prompt to avoid certain content"),
109+
seed: int = typer.Option(None, help="Seed for deterministic output"),
110+
aspect_ratio: str = typer.Option(None, help="Aspect ratio, e.g. '1:1', '16:9'"),
111+
output_format: str = typer.Option(None, help="Output format, e.g. 'jpeg', 'png'"),
112+
content_class: str = typer.Option(None, help="Content class: 'photo' or 'art'"),
104113
download: bool = typer.Option(False, help="Download the generated image to a file (filename is taken from the image URL)"),
105114
show_images: bool = typer.Option(False, help="Display the image in the terminal after download."),
106115
use_mocks: bool = typer.Option(False, help="Mock API responses for testing without a valid client secret."),
@@ -114,11 +123,39 @@ def generate(
114123
raise typer.BadParameter("client_id must be provided as an option or via the FIREFLY_CLIENT_ID environment variable.")
115124
if not client_secret:
116125
raise typer.BadParameter("client_secret must be provided as an option or via the FIREFLY_CLIENT_SECRET environment variable.")
117-
with with_maybe_use_mocks(use_mocks):
118-
_generate(client_id, client_secret, prompt, download, show_images, format, verbose)
119-
120-
121-
def _generate(client_id, client_secret, prompt, download, show_images, format, verbose):
126+
# Parse JSON for style/structure if provided
127+
style_obj = None
128+
if style:
129+
try:
130+
style_obj = json.loads(style)
131+
except Exception as e:
132+
raise typer.BadParameter(f"Invalid JSON for --style: {e}")
133+
structure_obj = None
134+
if structure:
135+
try:
136+
structure_obj = json.loads(structure)
137+
except Exception as e:
138+
raise typer.BadParameter(f"Invalid JSON for --structure: {e}")
139+
try:
140+
with with_maybe_use_mocks(use_mocks):
141+
_generate(
142+
client_id, client_secret, prompt, download, show_images, format, verbose,
143+
num_variations=num_variations,
144+
style=style_obj,
145+
structure=structure_obj,
146+
prompt_biasing_locale_code=prompt_biasing_locale_code,
147+
negative_prompt=negative_prompt,
148+
seed=seed,
149+
aspect_ratio=aspect_ratio,
150+
output_format=output_format,
151+
content_class=content_class,
152+
)
153+
except ValueError as e:
154+
typer.secho(str(e), fg=typer.colors.RED, err=True)
155+
raise typer.Exit(code=-1)
156+
157+
158+
def _generate(client_id, client_secret, prompt, download, show_images, format, verbose, **kwargs):
122159
client = FireflyClient(client_id=client_id, client_secret=client_secret)
123160
image_api_url = client.BASE_URL
124161
if verbose:
@@ -140,7 +177,7 @@ def capture_request(*args, **kwargs):
140177
_requests.request = capture_request
141178

142179
try:
143-
response = client.generate_image(prompt=prompt)
180+
response = client.generate_image(prompt=prompt, **kwargs)
144181
finally:
145182
_requests.request = orig_request
146183

firefly/client.py

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,13 +63,35 @@ def _request(
6363
except Exception as e:
6464
raise FireflyAPIError(f"Request failed: {e}")
6565

66-
def generate_image(self, prompt: str, **kwargs) -> FireflyImageResponse:
66+
def generate_image(
67+
self,
68+
prompt: str,
69+
num_variations: Optional[int] = None,
70+
style: Optional[dict] = None,
71+
structure: Optional[dict] = None,
72+
prompt_biasing_locale_code: Optional[str] = None,
73+
negative_prompt: Optional[str] = None,
74+
seed: Optional[int] = None,
75+
aspect_ratio: Optional[str] = None,
76+
output_format: Optional[str] = None,
77+
content_class: Optional[str] = None, # 'photo' or 'art'
78+
**kwargs
79+
) -> FireflyImageResponse:
6780
"""
6881
Generate an image from a text prompt using Adobe Firefly.
6982
7083
Args:
7184
prompt (str): The text prompt for image generation.
72-
**kwargs: Additional parameters for the API (e.g., style, aspect ratio).
85+
num_variations (int, optional): Number of images to generate (was `n`).
86+
style (dict, optional): Style object, e.g. {"presets": [...], "imageReference": {...}, "strength": ...}.
87+
structure (dict, optional): Structure reference object, e.g. {"strength": ..., "imageReference": {...}}.
88+
prompt_biasing_locale_code (str, optional): Locale code for prompt biasing (was `locale`).
89+
negative_prompt (str, optional): Negative prompt to avoid certain content.
90+
seed (int, optional): Seed for deterministic output.
91+
aspect_ratio (str, optional): Aspect ratio, e.g. "1:1", "16:9".
92+
output_format (str, optional): Output format, e.g. "jpeg", "png".
93+
content_class (str, optional): Content class, either 'photo' or 'art'.
94+
**kwargs: Additional parameters for the API.
7395
7496
Returns:
7597
FireflyImageResponse: The response object containing all fields from the API response.
@@ -78,6 +100,26 @@ def generate_image(self, prompt: str, **kwargs) -> FireflyImageResponse:
78100
FireflyAPIError, FireflyAuthError
79101
"""
80102
data = {"prompt": prompt}
103+
if num_variations is not None:
104+
data["numVariations"] = num_variations
105+
if style is not None:
106+
data["style"] = style
107+
if structure is not None:
108+
data["structure"] = structure
109+
if prompt_biasing_locale_code is not None:
110+
data["promptBiasingLocaleCode"] = prompt_biasing_locale_code
111+
if negative_prompt is not None:
112+
data["negativePrompt"] = negative_prompt
113+
if seed is not None:
114+
data["seed"] = seed
115+
if aspect_ratio is not None:
116+
data["aspectRatio"] = aspect_ratio
117+
if output_format is not None:
118+
data["outputFormat"] = output_format
119+
if content_class is not None:
120+
if content_class not in ("photo", "art"):
121+
raise ValueError("content_class must be either 'photo' or 'art'")
122+
data["contentClass"] = content_class
81123
data.update(kwargs)
82124
resp = self._request(method="POST", url=self.BASE_URL, json=data)
83125
try:

tests/test_cli.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,3 +137,85 @@ def test_generate_verbose(monkeypatch):
137137
# Should include verbose status messages
138138
assert "Doing request to" in result.output
139139
assert "Received HTTP" in result.output
140+
141+
def test_generate_with_all_new_options(monkeypatch):
142+
# Valid content_class: photo
143+
result = runner.invoke(
144+
app,
145+
[
146+
"image", "generate",
147+
"--client-id", "dummy_id",
148+
"--client-secret", "dummy_secret",
149+
"--prompt", "test",
150+
"--num-variations", "2",
151+
"--style", '{"presets": ["bw"], "strength": 50}',
152+
"--structure", '{"strength": 80, "imageReference": {"source": {"uploadId": "abc123"}}}',
153+
"--prompt-biasing-locale-code", "en-US",
154+
"--negative-prompt", "no text",
155+
"--seed", "42",
156+
"--aspect-ratio", "16:9",
157+
"--output-format", "jpeg",
158+
"--content-class", "photo",
159+
"--use-mocks"
160+
]
161+
)
162+
assert result.exit_code == 0
163+
assert "Generated image URL:" in result.output
164+
# Valid content_class: art
165+
result = runner.invoke(
166+
app,
167+
[
168+
"image", "generate",
169+
"--client-id", "dummy_id",
170+
"--client-secret", "dummy_secret",
171+
"--prompt", "test",
172+
"--content-class", "art",
173+
"--use-mocks"
174+
]
175+
)
176+
assert result.exit_code == 0
177+
assert "Generated image URL:" in result.output
178+
# Invalid content_class
179+
result = runner.invoke(
180+
app,
181+
[
182+
"image", "generate",
183+
"--client-id", "dummy_id",
184+
"--client-secret", "dummy_secret",
185+
"--prompt", "test",
186+
"--content-class", "invalid",
187+
"--use-mocks"
188+
]
189+
)
190+
assert result.exit_code != 0
191+
assert "content_class must be either 'photo' or 'art'" in result.output
192+
193+
def test_generate_invalid_json_style(monkeypatch):
194+
result = runner.invoke(
195+
app,
196+
[
197+
"image", "generate",
198+
"--client-id", "dummy_id",
199+
"--client-secret", "dummy_secret",
200+
"--prompt", "test",
201+
"--style", "not-a-json",
202+
"--use-mocks"
203+
]
204+
)
205+
assert result.exit_code == 2
206+
assert "Invalid JSON for --style" in result.output
207+
208+
def test_generate_invalid_json_structure(monkeypatch):
209+
result = runner.invoke(
210+
app,
211+
[
212+
"image", "generate",
213+
"--client-id", "dummy_id",
214+
"--client-secret", "dummy_secret",
215+
"--prompt", "test",
216+
"--structure", "not-a-json",
217+
"--use-mocks"
218+
]
219+
)
220+
assert result.exit_code == 2
221+
assert "Invalid JSON for --structure" in result.output

tests/test_firefly_client.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,3 +158,97 @@ def test_generate_image_value_error(client, mock_valid_ims_access_token_response
158158
with mock.patch("requests.request", side_effect=ValueError("bad value")):
159159
with pytest.raises(FireflyAPIError):
160160
client.generate_image(prompt="trigger value error")
161+
162+
163+
@responses.activate
164+
def test_generate_image_with_all_new_parameters(client, mock_valid_ims_access_token_response):
165+
responses.add(
166+
responses.POST,
167+
IMAGE_URL,
168+
json={
169+
"size": {"width": 1024, "height": 768},
170+
"outputs": [
171+
{"seed": 123, "image": {"url": "https://example.com/image.png"}}
172+
],
173+
"contentClass": "photo",
174+
},
175+
status=200,
176+
)
177+
style = {"presets": ["bw"], "strength": 50}
178+
structure = {"strength": 80, "imageReference": {"source": {"uploadId": "abc123"}}}
179+
response = client.generate_image(
180+
prompt="test prompt",
181+
num_variations=2,
182+
style=style,
183+
structure=structure,
184+
prompt_biasing_locale_code="en-US",
185+
negative_prompt="no text",
186+
seed=42,
187+
aspect_ratio="16:9",
188+
output_format="jpeg",
189+
extra_param="extra_value"
190+
)
191+
# Check the outgoing request body
192+
image_call = responses.calls[1]
193+
body = json.loads(image_call.request.body.decode())
194+
assert body["prompt"] == "test prompt"
195+
assert body["numVariations"] == 2
196+
assert body["style"] == style
197+
assert body["structure"] == structure
198+
assert body["promptBiasingLocaleCode"] == "en-US"
199+
assert body["negativePrompt"] == "no text"
200+
assert body["seed"] == 42
201+
assert body["aspectRatio"] == "16:9"
202+
assert body["outputFormat"] == "jpeg"
203+
assert body["extra_param"] == "extra_value"
204+
# Check the response is parsed correctly
205+
assert response.size.width == 1024
206+
assert response.size.height == 768
207+
assert response.contentClass == "photo"
208+
assert len(response.outputs) == 1
209+
assert response.outputs[0].seed == 123
210+
assert response.outputs[0].image.url == "https://example.com/image.png"
211+
212+
213+
@responses.activate
214+
def test_generate_image_content_class(client, mock_valid_ims_access_token_response):
215+
responses.add(
216+
responses.POST,
217+
IMAGE_URL,
218+
json={
219+
"size": {"width": 512, "height": 512},
220+
"outputs": [
221+
{"seed": 1, "image": {"url": "https://example.com/img.png"}}
222+
],
223+
"contentClass": "photo",
224+
},
225+
status=200,
226+
)
227+
# Valid value: 'photo'
228+
response = client.generate_image(prompt="test", content_class="photo")
229+
image_call = responses.calls[1]
230+
body = json.loads(image_call.request.body.decode())
231+
assert body["contentClass"] == "photo"
232+
assert response.contentClass == "photo"
233+
# Valid value: 'art'
234+
responses.reset()
235+
responses.add(
236+
responses.POST,
237+
IMAGE_URL,
238+
json={
239+
"size": {"width": 512, "height": 512},
240+
"outputs": [
241+
{"seed": 2, "image": {"url": "https://example.com/img2.png"}}
242+
],
243+
"contentClass": "art",
244+
},
245+
status=200,
246+
)
247+
response = client.generate_image(prompt="test", content_class="art")
248+
image_call = responses.calls[0]
249+
body = json.loads(image_call.request.body.decode())
250+
assert body["contentClass"] == "art"
251+
assert response.contentClass == "art"
252+
# Invalid value
253+
with pytest.raises(ValueError):
254+
client.generate_image(prompt="test", content_class="invalid")

0 commit comments

Comments
 (0)