47 changes: 24 additions & 23 deletions vllm/model_executor/models/siglip2navit.py
@@ -16,6 +16,7 @@
from vllm.platforms import _Backend, current_platform

from .vision import get_vit_attn_backend
from vllm.model_executor.layers.rotary_embedding import _apply_rotary_emb_torch

is_hpu = current_platform.is_hpu()

@@ -144,26 +145,26 @@ def rotate_half(x, interleaved=False):
two=2)


def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
"""
x: (batch_size, seqlen, nheads, headdim)
cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
"""
ro_dim = cos.shape[-1] * 2
assert ro_dim <= x.shape[-1]
cos = repeat(
cos,
"... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
sin = repeat(
sin,
"... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
return torch.cat(
[
x[..., :ro_dim] * cos +
rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]
],
dim=-1,
)
# def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
# """
# x: (batch_size, seqlen, nheads, headdim)
# cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
# """
# ro_dim = cos.shape[-1] * 2
# assert ro_dim <= x.shape[-1]
# cos = repeat(
# cos,
# "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
# sin = repeat(
# sin,
# "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
# return torch.cat(
# [
# x[..., :ro_dim] * cos +
# rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]
# ],
# dim=-1,
# )


def apply_rotary_pos_emb(
@@ -179,11 +180,11 @@ def apply_rotary_pos_emb(
from flash_attn.layers.rotary import apply_rotary_emb
apply_rotary_emb_func = apply_rotary_emb
else:
apply_rotary_emb_func = apply_rotary_emb_torch
apply_rotary_emb_func = _apply_rotary_emb_torch
q_embed = apply_rotary_emb_func(q.float(), cos.float(),
sin.float()).type_as(q)
sin.float(), False).type_as(q)
k_embed = apply_rotary_emb_func(k.float(), cos.float(),
sin.float()).type_as(k)
sin.float(), False).type_as(k)
return q_embed, k_embed
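
For orientation, a self-contained sketch of what a rotary helper with the `(x, cos, sin, is_neox_style)` signature typically computes. This is illustrative only and assumes, rather than verifies, that vLLM's `_apply_rotary_emb_torch` follows the same convention, with the `False` passed above selecting the interleaved (GPT-J-style) pairing:

import torch

def apply_rotary_emb_reference(x: torch.Tensor, cos: torch.Tensor,
                               sin: torch.Tensor,
                               is_neox_style: bool) -> torch.Tensor:
    # x: (..., rotary_dim); cos/sin broadcastable to (..., rotary_dim // 2).
    if is_neox_style:
        # NeoX style: rotate the two contiguous halves against each other.
        x1, x2 = torch.chunk(x, 2, dim=-1)
    else:
        # GPT-J style: rotate adjacent (even, odd) channel pairs.
        x1, x2 = x[..., 0::2], x[..., 1::2]
    o1 = x1 * cos - x2 * sin
    o2 = x2 * cos + x1 * sin
    if is_neox_style:
        return torch.cat((o1, o2), dim=-1)
    # Re-interleave the rotated pairs back into the original channel order.
    return torch.stack((o1, o2), dim=-1).flatten(-2)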


35 changes: 35 additions & 0 deletions vllm/worker/hpu_worker.py
@@ -14,6 +14,11 @@
import habana_frameworks.torch as htorch # noqa:F401
import torch
import torch.distributed

from vllm_hpu_extension.debug import init_debug_logger
from vllm_hpu_extension.profiler import (HabanaMemoryProfiler, format_bytes,
setup_profiler)
from vllm_hpu_extension.runtime import get_config
from vllm_hpu_extension.profiler import HabanaMemoryProfiler, format_bytes

import vllm.envs as envs
@@ -38,6 +43,12 @@

logger = init_logger(__name__)

def setup_step_profiler(steps):
if steps is None:
return None
step_start, step_end = steps
active = step_end - step_start + 1
return setup_profiler(warmup=0, active=active)
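
Assuming VLLM_PROFILE_STEPS resolves to an inclusive `(start, end)` pair of engine steps, the helper above sizes the capture window so that every step in that range is recorded. A hypothetical example:

# Illustrative only: engine steps 100..104 inclusive -> no warmup, five active steps.
profile_steps = (100, 104)
step_profiler = setup_step_profiler(profile_steps)  # setup_profiler(warmup=0, active=5)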

class HPUWorker(LocalOrDistributedWorkerBase):
"""A worker class that executes (a partition of) the model on a HPU.
@@ -122,6 +133,10 @@ def __init__(
on_trace_ready=fn(torch_profiler_trace_dir, use_gzip=True))
else:
self.profiler = None
self.step = 0
self.profile_steps = get_config().VLLM_PROFILE_STEPS
self.step_profiler = setup_step_profiler(self.profile_steps)
self.step_debug = init_debug_logger('steps')

def _is_encoder_decoder_model(self):
return self.model_config.is_encoder_decoder
@@ -191,6 +206,10 @@ def execute_model(
self,
execute_model_req: Optional[ExecuteModelRequest] = None,
) -> Optional[List[SamplerOutput]]:
if self.step_debug:
self.step_debug(f'step={self.step}')
if self.step_profiler and self.step == self.profile_steps[0]:
self.step_profiler.start()
# VLLM_HPU_LOG_STEP_GRAPH_COMPILATION - will log graph compilations per engine step, only when there was any - highly recommended to use alongside PT_HPU_METRICS_GC_DETAILS! # noqa:E501
# VLLM_HPU_LOG_STEP_GRAPH_COMPILATION_ALL - will log graph compilations per engine step, always, even if there were none # noqa:E501
# VLLM_HPU_LOG_STEP_CPU_FALLBACKS - will log cpu fallbacks per engine step, only when there was any # noqa:E501
@@ -249,11 +268,27 @@ def execute_model(
msg = ("VLLM_HPU_STEP_CPU_FALLBACK: "
f"{cpu_fallback_local_metric.stats()}, {input_stats}")
logger.warning(msg)
if self.step_profiler:
if self.step >= self.profile_steps[0]:
self.step_profiler.step()
if self.step == self.profile_steps[1]:
self.step_profiler.stop()
self.step_profiler = None
raise RuntimeError('Step profiling finished!')
self.step += 1

return output

output = LocalOrDistributedWorkerBase.execute_model(
self, execute_model_req)
if self.step_profiler:
if self.step >= self.profile_steps[0]:
self.step_profiler.step()
if self.step == self.profile_steps[1]:
self.step_profiler.stop()
self.step_profiler = None
raise RuntimeError('Step profiling finished!')
self.step += 1
return output
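
A rough distillation of the per-step bookkeeping above, written for illustration rather than taken from the worker: the profiler starts at the first step of the configured window, records every step inside it, and at the last step stops and deliberately aborts the run.

def on_engine_step(step: int, window: tuple, profiler) -> None:
    # Sketch only; the real logic is split across the top and bottom of execute_model.
    start, end = window
    if profiler is None:
        return
    if step == start:
        profiler.start()      # open the capture window
    if step >= start:
        profiler.step()       # record this engine step
    if step == end:
        profiler.stop()       # close the window...
        raise RuntimeError('Step profiling finished!')  # ...and end the run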

@torch.inference_mode()