Skip to content

Commit c67e8de

Browse files
authored
Fix for Ray MBridge, NeMo logprob eval benchmarks (#445)
Signed-off-by: Abhishree <abhishreetm@gmail.com>
1 parent 6558d7f commit c67e8de

File tree

2 files changed

+538
-8
lines changed

2 files changed

+538
-8
lines changed

nemo_deploy/llm/megatronllm_deployable_ray.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16+
import json
1617
import logging
1718
import os
1819
import random
@@ -264,7 +265,6 @@ def __init__(
264265
async def completions(self, request: Dict[Any, Any]):
265266
"""Handle text completion requests."""
266267
try:
267-
print("request", request)
268268
if "prompt" in request:
269269
request["prompts"] = [request["prompt"]]
270270
temperature = request.get("temperature", 0.0)
@@ -290,7 +290,6 @@ async def completions(self, request: Dict[Any, Any]):
290290

291291
# Run tokenization and model inference on the primary Ray worker (ray.get blocks until the result is ready)
292292
results = ray.get(self.primary_worker.infer.remote(inference_inputs))
293-
print("results", results)
294293
# Extract generated texts from results
295294
generated_texts = results.get("sentences", [])
296295

@@ -302,12 +301,14 @@ async def completions(self, request: Dict[Any, Any]):
302301
# Convert numpy arrays to Python lists for JSON serialization
303302
log_probs_data = results.get("log_probs", None)
304303
if log_probs_data is not None and isinstance(log_probs_data, np.ndarray):
305-
log_probs_data = log_probs_data.tolist()
304+
# log_probs_data is a numpy array here; convert it to a list and keep only the first element
305+
log_probs_data = log_probs_data.tolist()[0]
306306

307-
# Convert numpy arrays to Python lists for JSON serialization
308307
top_log_probs_data = results.get("top_logprobs", None)
309-
if top_log_probs_data is not None and isinstance(top_log_probs_data, np.ndarray):
310-
top_log_probs_data = top_log_probs_data.tolist()
308+
if top_log_probs_data is not None:
309+
# top_log_probs_data is a list of JSON strings; only the first element is used,
310+
# so parse that string into a Python object with json.loads
311+
top_log_probs_data = json.loads(top_log_probs_data[0])
311312

312313
output = {
313314
"id": f"cmpl-{int(time.time())}",
@@ -339,8 +340,9 @@ async def completions(self, request: Dict[Any, Any]):
339340
}
340341
if request.get("echo", False):
341342
# output format requires empty logprobs for the 1st token if echo is True
342-
output["choices"][0]["logprobs"]["token_logprobs"][0].insert(0, None)
343-
print("output", output)
343+
output["choices"][0]["logprobs"]["token_logprobs"].insert(0, None)
344+
# Uncomment the line below to inspect the output when debugging an invalid accuracy score or output.
345+
# LOGGER.warning(f"Output: {output}")
344346
return output
345347
except Exception as e:
346348
LOGGER.error(f"Error during inference: {str(e)}")

0 commit comments

Comments
 (0)