Skip to content

Commit 485e172

Browse files
authored
Merge pull request #116 from AnFreTh/main
Release v0.2.2
2 parents bc3ddfb + 90dfedb commit 485e172

File tree

10 files changed

+399
-72
lines changed

10 files changed

+399
-72
lines changed

README.md

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,35 @@ You can install STREAM directly from PyPI or from the GitHub repository:
8181
pip install git+https://github.com/AnFreTh/STREAM.git
8282
```
8383

84+
3. **Download necessary NLTK resources**:
85+
86+
To download all necessary NLTK resources required for some models, simply run:
87+
88+
```python
89+
import nltk
90+
91+
def ensure_nltk_resources():
92+
resources = [
93+
"stopwords",
94+
"wordnet",
95+
"punkt_tab",
96+
"brown",
97+
"averaged_perceptron_tagger"
98+
]
99+
for resource in resources:
100+
try:
101+
nltk.data.find(resource)
102+
except LookupError:
103+
try:
104+
print(f"Downloading NLTK resource: {resource}")
105+
nltk.download(resource)
106+
except Exception as e:
107+
print(f"Failed to download {resource}: {e}")
108+
109+
ensure_nltk_resources()
110+
```
111+
112+
84113
4. **Install requirements for add-ons**:
85114
To use STREAM's visualizations, simply run:
86115
```bash

assets/movie_poster_topic1.png

392 KB
Loading

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def run(self):
5454
"plotting": ["dash", "plotly", "matplotlib", "wordcloud"],
5555
"bertopic": ["hdbscan"],
5656
"dcte": ["pyarrow", "setfit"],
57+
"experimental": ["openai"],
5758
}
5859

5960

stream_topic/__version__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
"""Version information."""
22

33
# The following line *must* be the last in the module, exactly as formatted:
4-
__version__ = "0.2.0"
4+
__version__ = "0.2.2"
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from .topic_poster import movie_poster
2+
from .topic_story import story_topic
3+
from .topic_summary import topic_summaries
4+
5+
6+
__all__ = [
7+
"movie_poster",
8+
"story_topic",
9+
"topic_summaries",
10+
]
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
from openai import OpenAI, OpenAIError
2+
from loguru import logger
3+
import os
4+
from IPython.display import Image, display
5+
6+
# Define allowed models
7+
ALLOWED_MODELS = ["dall-e-3", "dall-e-2"]
8+
9+
10+
def movie_poster(
    topic,
    api_key,
    model="dall-e-3",
    quality="standard",
    prompt="Create a movie poster that depicts that topic best. Note that the words are ordered in decreasing order of their importance.",
    size="1024x1024",
    return_style="url",
):
    """
    Generate a movie-poster-style image for a topic using OpenAI's DALL-E models.

    Parameters:
    - topic: List of words/phrases, or list of (word, importance) tuples,
      ordered by decreasing importance.
    - api_key: API key for OpenAI; if None, read from the OPENAI_API_KEY
      environment variable.
    - model: Image model to use; must be one of ALLOWED_MODELS.
    - quality: Image quality passed to the API (e.g. "standard").
    - prompt: Instruction appended to the topic description.
    - size: Size of the generated image, default "1024x1024".
    - return_style: "url" to return the image URL, "plot" to display the
      image inline (IPython) and return None.

    Returns:
    - image_url: URL of the generated image when return_style == "url".

    Raises:
    - ValueError: if no API key is available, topic is empty, return_style
      is invalid, or model is not allowed.
    """

    # Load the API key from environment if not provided
    if api_key is None:
        api_key = os.getenv("OPENAI_API_KEY")

    if api_key is None:
        raise ValueError("API key is missing. Please provide an API key.")

    if not topic:
        raise ValueError("Topic must contain at least one word.")

    # Raise instead of assert: asserts are stripped under `python -O`.
    if return_style not in ("url", "plot"):
        raise ValueError("Invalid return style. Please choose 'url' or 'plot'")

    # Initialize the OpenAI client with your API key
    client = OpenAI(api_key=api_key)

    # Validate model
    if model not in ALLOWED_MODELS:
        raise ValueError(
            f"Invalid model. Please choose a valid model from {ALLOWED_MODELS}."
        )

    # Create the prompt for the movie poster
    if isinstance(topic[0], tuple):
        # Topic is a list of (word, importance) tuples
        topic_description = ", ".join(
            f"{word} (importance: {importance})" for word, importance in topic
        )
    else:
        # Topic is a list of words in descending importance; join them so the
        # prompt contains plain text rather than a Python list repr.
        topic_description = ", ".join(str(word) for word in topic)

    image_prompt = f"Given the following topic: {topic_description}. {prompt}"

    # Logging the operation
    logger.info(f"--- Generating image with model: {model} ---")
    response = client.images.generate(
        model=model,
        prompt=image_prompt,
        size=size,
        quality=quality,
        n=1,
    )

    # Ensure the response is valid
    if response:
        image_url = response.data[0].url
    else:
        image_url = "No image generated. Please try again."

    if return_style == "url":
        return image_url

    elif return_style == "plot":
        display(Image(url=image_url))
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
from openai import OpenAI, OpenAIError
2+
from loguru import logger
3+
import os
4+
5+
6+
# Define allowed models
7+
ALLOWED_MODELS = [
8+
"gpt-3.5-turbo",
9+
"gpt-3.5-turbo-16k",
10+
"gpt-4",
11+
"gpt-4o",
12+
"gpt-4o-mini",
13+
"gpt-4-turbo",
14+
]
15+
16+
17+
def story_topic(
    topic,
    api_key,
    model="gpt-3.5-turbo-16k",
    content="You are a creative writer.",
    prompt="Create a creative story that includes the following words:",
    max_tokens=250,
    temperature=0.8,
    top_p=1.0,
    frequency_penalty=0.0,
    presence_penalty=0.0,
):
    """
    Generate a creative story using OpenAI's GPT model.

    Parameters:
    - topic: List of words or phrases to include in the story.
    - api_key: API key for OpenAI; if None, read from the OPENAI_API_KEY
      environment variable.
    - model: Model to use; must be one of ALLOWED_MODELS.
    - content: Initial system content.
    - prompt: Prompt for the story generation (the topic words are appended).
    - max_tokens: Maximum tokens for the response.
    - temperature: Creativity level for the model.
    - top_p: Nucleus sampling parameter.
    - frequency_penalty: Penalty for word frequency.
    - presence_penalty: Penalty for word presence.

    Returns:
    - story: Generated story as a string, or an error message if the
      API call fails.

    Raises:
    - ValueError: if no API key is available or the model is not allowed.
    """

    # Load the API key from environment if not provided
    if api_key is None:
        api_key = os.getenv("OPENAI_API_KEY")

    # Fail fast with a clear message (consistent with movie_poster) instead
    # of letting the client raise an opaque authentication error later.
    if api_key is None:
        raise ValueError("API key is missing. Please provide an API key.")

    # Initialize the OpenAI client with your API key
    client = OpenAI(api_key=api_key)

    # Validate model
    if model not in ALLOWED_MODELS:
        raise ValueError(
            f"Invalid model. Please choose a valid model from {ALLOWED_MODELS}."
        )

    # Build the user prompt in a new variable (avoid shadowing the parameter).
    # The default prompt already ends with ":", so no extra colon is added
    # here — the original produced "...following words:: a, b".
    user_prompt = f"{prompt} {', '.join(topic)}. Make it as short as {max_tokens} words."

    # Logging the operation
    logger.info(f"--- Generating story with model: {model} ---")

    try:
        response = client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": content},
                {"role": "user", "content": user_prompt},
            ],
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty,
        )

        # Ensure the response is valid
        if response and len(response.choices) > 0:
            story = response.choices[0].message.content
        else:
            story = "No story generated. Please try again."

        return story

    except OpenAIError as e:
        # Best-effort API: surface the error as the return value rather
        # than raising, matching topic_summaries' behavior.
        return f"An error occurred: {str(e)}"
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
from openai import OpenAI, OpenAIError
2+
from loguru import logger
3+
import os
4+
5+
ALLOWED_MODELS = [
6+
"gpt-3.5-turbo",
7+
"gpt-3.5-turbo-16k",
8+
"gpt-4",
9+
"gpt-4o",
10+
"gpt-4o-mini",
11+
"gpt-4-turbo",
12+
]
13+
14+
15+
def topic_summaries(
    topics,
    api_key,
    model="gpt-3.5-turbo-16k",
    content="You are a creative writer.",
    prompt="Provide a 1-2 sentence summary for the following topic:",
    max_tokens=60,
    temperature=0.7,
    top_p=1.0,
    frequency_penalty=0.0,
    presence_penalty=0.0,
):
    """
    Generate a 1-2 sentence summary for each topic using OpenAI's GPT model.

    Parameters:
    - topics: List of lists, where each sublist contains words/phrases
      representing a topic.
    - api_key: API key for OpenAI; if None, read from the OPENAI_API_KEY
      environment variable.
    - model: Model to use; must be one of ALLOWED_MODELS.
    - content: Initial system content.
    - prompt: Prompt for the summary generation (topic words are appended).
    - max_tokens: Maximum tokens for each summary.
    - temperature: Creativity level for the model.
    - top_p: Nucleus sampling parameter.
    - frequency_penalty: Penalty for word frequency.
    - presence_penalty: Penalty for word presence.

    Returns:
    - summaries: List of summaries (or per-topic error messages) in the
      same order as `topics`.

    Raises:
    - ValueError: if no API key is available or the model is not allowed.
    """

    # Load the API key from environment if not provided
    if api_key is None:
        api_key = os.getenv("OPENAI_API_KEY")

    # Fail fast with a clear message (consistent with movie_poster) instead
    # of letting the client raise an opaque authentication error later.
    if api_key is None:
        raise ValueError("API key is missing. Please provide an API key.")

    # Initialize the OpenAI client with your API key
    client = OpenAI(api_key=api_key)

    # Validate model
    if model not in ALLOWED_MODELS:
        raise ValueError(
            f"Invalid model. Please choose a valid model from {ALLOWED_MODELS}."
        )

    summaries = []

    for idx, topic in enumerate(topics):
        # Create the prompt for each topic
        topic_prompt = f"{prompt} {', '.join(topic)}."

        # Logging the operation
        logger.info(f"--- Generating summary for topic {idx} with model: {model} ---")

        # Per-topic try/except: one failed API call yields an error entry
        # for that topic but does not abort the remaining topics.
        try:
            response = client.chat.completions.create(
                model=model,
                messages=[
                    {"role": "system", "content": content},
                    {"role": "user", "content": topic_prompt},
                ],
                max_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                frequency_penalty=frequency_penalty,
                presence_penalty=presence_penalty,
            )

            # Ensure the response is valid
            if response and len(response.choices) > 0:
                summary = response.choices[0].message.content
            else:
                summary = "No summary generated. Please try again."

            summaries.append(summary)

        except OpenAIError as e:
            summaries.append(f"An error occurred: {str(e)}")

    return summaries

0 commit comments

Comments
 (0)