 "[Azure OpenAI Studio](https://oai.azure.com/resource/overview)"
 "[View the source code](https://github.com/ks6088ts-labs/workshop-azure-openai/blob/main/apps/8_streamlit_azure_openai_batch/main.py)"

+
+def is_configured():
+    return azure_openai_api_key and azure_openai_endpoint and azure_openai_api_version and azure_openai_gpt_model
+
+
+def get_client():
+    return AzureOpenAI(
+        api_key=azure_openai_api_key,
+        api_version=azure_openai_api_version,
+        azure_endpoint=azure_openai_endpoint,
+    )
+
+
 st.title("8_streamlit_azure_openai_batch")

-if not azure_openai_api_key or not azure_openai_endpoint or not azure_openai_api_version or not azure_openai_gpt_model:
+if not is_configured():
     st.warning("Please fill in the required fields at the sidebar.")
-    st.stop()

 # ---------------
 # Upload batch file
 # ---------------
 st.header("Upload batch file")
 st.info("Upload a file in JSON lines format (.jsonl)")
-client = AzureOpenAI(
-    api_key=azure_openai_api_key,
-    api_version=azure_openai_api_version,
-    azure_endpoint=azure_openai_endpoint,
-)
 uploaded_file = st.file_uploader("Upload an input file in JSON lines format", type=("jsonl"))
 if uploaded_file:
     bytes_data = uploaded_file.read()
     st.write(bytes_data.decode().split("\n"))
-    submit_button = st.button("Submit", key="submit")
-    if submit_button:
+    if st.button(
+        "Submit",
+        key="submit",
+        disabled=not is_configured(),
+    ):
         temp_file_path = "tmp.jsonl"
         with open(temp_file_path, "wb") as f:
             f.write(bytes_data)
         with st.spinner("Uploading..."):
             try:
-                response = client.files.create(
+                response = get_client().files.create(
                     # FIXME: hardcoded for now, use uploaded_file
                     file=open(temp_file_path, "rb"),
                     purpose="batch",
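For reference, each line of the uploaded .jsonl file is one self-contained request. A minimal sketch of producing such a file, assuming the documented batch input layout (`custom_id`, `method`, `url`, `body`); the deployment name `gpt-4o-batch` and the prompts are placeholders:

import json

# Sketch: write a two-request batch input file in the assumed layout.
requests = [
    {
        "custom_id": f"task-{i}",  # your own id, echoed back in the output
        "method": "POST",
        "url": "/chat/completions",
        "body": {
            "model": "gpt-4o-batch",  # Azure OpenAI *deployment* name (placeholder)
            "messages": [{"role": "user", "content": prompt}],
        },
    }
    for i, prompt in enumerate(["Hello, world", "Tell me a joke"])
]

with open("input.jsonl", "w") as f:
    for request in requests:
        f.write(json.dumps(request) + "\n")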
...

     key="track_file_id",
     help="Enter the file ID to track the file upload status",
 )
-track_button = st.button("Track")
-if track_file_id != "" and track_button:
+if st.button(
+    "Track",
+    key="track",
+    disabled=not track_file_id or not is_configured(),
+):
     with st.spinner("Tracking..."):
         try:
-            response = client.files.retrieve(track_file_id)
+            response = get_client().files.retrieve(track_file_id)
             st.write(response.model_dump())
             st.write(f"status: {response.status}")
         except Exception as e:
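Rather than clicking Track repeatedly, the same retrieve call can be polled until the file settles. A rough sketch; the status values here are assumptions about what `response.status` reports, so adjust them to what the API actually returns:

import time

def wait_for_file(client, file_id, interval_seconds=5):
    # Sketch: poll files.retrieve until the upload finishes processing.
    while True:
        file = client.files.retrieve(file_id)
        if file.status in ("processed", "error"):  # assumed terminal statuses
            return file
        time.sleep(interval_seconds)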
...

     key="batch_file_id",
     help="Enter the file ID to track the file upload status",
 )
-batch_button = st.button("Create batch job")
-if batch_file_id != "" and batch_button:
+if st.button(
+    "Create batch job",
+    key="create",
+    disabled=not batch_file_id or not is_configured(),
+):
     with st.spinner("Creating..."):
         try:
-            response = client.batches.create(
+            response = get_client().batches.create(
                 input_file_id=batch_file_id,
                 endpoint="/chat/completions",
                 completion_window="24h",
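The create call returns a batch object; its `id` is what the tracking step below expects. A sketch reusing the app's `get_client()` helper and a `batch_file_id` from the upload step:

# Sketch: create a batch job and keep its id for tracking.
batch = get_client().batches.create(
    input_file_id=batch_file_id,
    endpoint="/chat/completions",
    completion_window="24h",
)
print(batch.id)      # pass this to the "Track batch job" step
print(batch.status)  # a new job typically starts out as "validating"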
...

     key="track_batch_job_id",
     help="Enter the batch job ID to track the job progress",
 )
-track_batch_job_button = st.button("Track batch job")
-if track_batch_job_id != "" and track_batch_job_button:
+if st.button(
+    "Track batch job",
+    key="track_batch_job",
+    disabled=not track_batch_job_id or not is_configured(),
+):
     with st.spinner("Tracking..."):
         try:
-            response = client.batches.retrieve(track_batch_job_id)
+            response = get_client().batches.retrieve(track_batch_job_id)
             st.write(response.model_dump())
             st.write(f"status: {response.status}")
             st.write(f"output_file_id: {response.output_file_id}")
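Batch jobs run asynchronously and may take up to the completion window to finish, so tracking usually means polling until the job settles. A sketch, assuming the Batch API's documented terminal statuses:

import time

TERMINAL_STATUSES = {"completed", "failed", "expired", "cancelled"}

def wait_for_batch(client, batch_id, interval_seconds=30):
    # Sketch: poll batches.retrieve until the job reaches a terminal status.
    while True:
        batch = client.batches.retrieve(batch_id)
        if batch.status in TERMINAL_STATUSES:
            return batch  # batch.output_file_id is set on success
        time.sleep(interval_seconds)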
...

     key="retrieve_batch_job_id",
     help="Enter the batch job ID to retrieve the output file",
 )
-retrieve_batch_job_button = st.button("Retrieve batch job output file")
-if output_file_id != "" and retrieve_batch_job_button:
+if st.button(
+    "Retrieve batch job output file",
+    key="retrieve_batch_job",
+    disabled=not output_file_id or not is_configured(),
+):
     with st.spinner("Retrieving..."):
         try:
-            file_response = client.files.content(output_file_id)
+            file_response = get_client().files.content(output_file_id)
             raw_responses = file_response.text.strip().split("\n")

             for raw_response in raw_responses:
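Each output line wraps one request's result as JSON. A sketch of a loop body, assuming the batch output layout (`custom_id` echoes the matching input line; `response.body` is an ordinary chat completion payload):

import json

for raw_response in raw_responses:
    result = json.loads(raw_response)
    custom_id = result["custom_id"]  # results can arrive out of input order
    if result.get("error"):
        print(custom_id, "failed:", result["error"])
        continue
    body = result["response"]["body"]
    print(custom_id, body["choices"][0]["message"]["content"])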