Skip to content

Commit 6b29181

Browse files
[python]remove keys to check param dict (#2677)
1 parent 69bbe22 commit 6b29181

File tree

3 files changed

+10
-11
lines changed

3 files changed

+10
-11
lines changed

engines/python/setup/djl_python/rolling_batch/trtllm_rolling_batch.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -63,20 +63,20 @@ def translate_triton_params(self, parameters: dict) -> dict:
6363
6464
:return: The same parameters dict, but with TensorRT-LLM style parameter names.
6565
"""
66-
if "request_output_len" not in parameters.keys():
66+
if "request_output_len" not in parameters:
6767
parameters["request_output_len"] = parameters.pop(
6868
"max_new_tokens", 30)
69-
if "top_k" in parameters.keys():
69+
if "top_k" in parameters:
7070
parameters["runtime_top_k"] = parameters.pop("top_k")
71-
if "top_p" in parameters.keys():
71+
if "top_p" in parameters:
7272
parameters["runtime_top_p"] = parameters.pop("top_p")
73-
if "seed" in parameters.keys():
73+
if "seed" in parameters:
7474
parameters["random_seed"] = int(parameters.pop("seed"))
7575
if parameters.pop("do_sample", False):
7676
parameters["runtime_top_k"] = parameters.get("runtime_top_k", 5)
7777
parameters["runtime_top_p"] = parameters.get("runtime_top_p", 0.85)
7878
parameters["temperature"] = parameters.get("temperature", 0.8)
79-
if "length_penalty" in parameters.keys():
79+
if "length_penalty" in parameters:
8080
parameters['len_penalty'] = parameters.pop('length_penalty')
8181
parameters["streaming"] = parameters.pop(
8282
"stream", parameters.get("streaming", True))

engines/python/setup/djl_python/transformers_neuronx_scheduler/slot.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,7 @@
2929

3030
def translate_neuronx_params(parameters: dict) -> dict:
3131
# TODO: Remove this once presence_penalty is supported
32-
if "presence_penalty" in parameters.keys(
33-
) and "repetition_penalty" not in parameters.keys():
32+
if "presence_penalty" in parameters and "repetition_penalty" not in parameters:
3433
parameters["repetition_penalty"] = float(
3534
parameters.pop("presence_penalty")) + 2.0
3635
return parameters

engines/python/setup/djl_python/utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def is_best_of(parameters: dict) -> bool:
6969
:param parameters: parameters dictionary
7070
:return: boolean
7171
"""
72-
return "best_of" in parameters.keys() and parameters.get("best_of") > 1
72+
return "best_of" in parameters and parameters.get("best_of") > 1
7373

7474

7575
def is_beam_search(parameters: dict) -> bool:
@@ -78,7 +78,7 @@ def is_beam_search(parameters: dict) -> bool:
7878
:param parameters: parameters dictionary
7979
:return: boolean
8080
"""
81-
return "num_beams" in parameters.keys() and parameters.get("num_beams") > 1
81+
return "num_beams" in parameters and parameters.get("num_beams") > 1
8282

8383

8484
def is_multiple_sequences(parameters: dict) -> bool:
@@ -88,7 +88,7 @@ def is_multiple_sequences(parameters: dict) -> bool:
8888
:param parameters: parameters dictionary
8989
:return: boolean
9090
"""
91-
return "n" in parameters.keys() and parameters.get("n") > 1
91+
return "n" in parameters and parameters.get("n") > 1
9292

9393

9494
def is_streaming(parameters: dict) -> bool:
@@ -97,7 +97,7 @@ def is_streaming(parameters: dict) -> bool:
9797
:param parameters: parameters dictionary
9898
:return: boolean
9999
"""
100-
return "stream" in parameters.keys() and parameters.get("stream")
100+
return "stream" in parameters and parameters.get("stream")
101101

102102

103103
def wait_till_generation_finished(parameters):

0 commit comments

Comments
 (0)