Skip to content

Commit 1013ecb

Browse files
matthiasdiener, alexfikl, inducer
authored
PytatoPyOpenCLArrayContext: add support for kernel profiling (#311)
* Fix ruff 0.11.4 error
* PytatoPyOpenCLArrayContext: add support for kernel profiling
* add a simple test
* remove global pyopencl import
* replace argument indicator

Co-authored-by: Alexandru Fikl <[email protected]>

* add add_profiling_event
* outline multi-kernel per t_unit implementation
* small fixes
* refactor tabulate_profiling_data
* add test to disable
* rename to private fields/method

Co-authored-by: Andreas Klöckner <[email protected]>

* refactor to simplify API
* bit more doc
* factor out profile enable/disable

---------

Co-authored-by: Alexandru Fikl <[email protected]>
Co-authored-by: Andreas Klöckner <[email protected]>
1 parent 9300134 commit 1013ecb

File tree

4 files changed

+231
-5
lines changed

4 files changed

+231
-5
lines changed

arraycontext/impl/pytato/__init__.py

Lines changed: 75 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
import abc
5555
import sys
5656
from collections.abc import Callable
57+
from dataclasses import dataclass
5758
from typing import TYPE_CHECKING, Any
5859

5960
import numpy as np
@@ -74,7 +75,6 @@
7475

7576
if TYPE_CHECKING:
7677
import loopy as lp
77-
import pyopencl as cl
7878
import pytato
7979

8080
if getattr(sys, "_BUILDING_SPHINX_DOCS", False):
@@ -235,6 +235,16 @@ def get_target(self):
235235

236236
# {{{ PytatoPyOpenCLArrayContext
237237

238+
239+
@dataclass
class ProfileEvent:
    """A pending profiling record for one execution of a translation unit.

    A record stays pending until the profiler collects it, i.e. until its
    timing has been read out of the two OpenCL events it carries.

    .. attribute:: start_cl_event

        Event marking the beginning of the measured interval.

    .. attribute:: stop_cl_event

        Event marking the end of the measured interval.

    .. attribute:: t_unit_name

        Name under which the measured time is reported.
    """

    start_cl_event: cl._cl.Event
    stop_cl_event: cl._cl.Event
    t_unit_name: str
246+
247+
238248
class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext):
239249
"""
240250
An :class:`ArrayContext` that uses :mod:`pytato` data types to represent
@@ -259,7 +269,7 @@ def __init__(
259269
self, queue: cl.CommandQueue, allocator=None, *,
260270
use_memory_pool: bool | None = None,
261271
compile_trace_callback: Callable[[Any, str, Any], None] | None = None,
262-
272+
profile_kernels: bool = False,
263273
# do not use: only for testing
264274
_force_svm_arg_limit: int | None = None,
265275
) -> None:
@@ -322,6 +332,59 @@ def __init__(
322332

323333
self._force_svm_arg_limit = _force_svm_arg_limit
324334

335+
self._enable_profiling(profile_kernels)
336+
337+
# {{{ Profiling functionality
338+
339+
def _enable_profiling(self, enable: bool) -> None:
    """Reset all profiling state and switch kernel profiling on or off.

    Any pending events and already collected results are dropped, regardless
    of the value of *enable*.

    :arg enable: whether kernels should be profiled from now on.
    :raises RuntimeError: if *enable* is true but the command queue was not
        created with ``cl.command_queue_properties.PROFILING_ENABLE``.
    """
    # Events recorded at launch time whose timings were not read out yet
    self._profile_events: list[ProfileEvent] = []

    # Maps a kernel name to the list of its measured execution times
    self._profile_results: dict[str, list[int]] = {}

    if not enable:
        self.profile_kernels = False
        return

    import pyopencl as cl

    queue_properties = self.queue.properties
    profiling_flag = cl.command_queue_properties.PROFILING_ENABLE

    # Event timestamps are only available on a profiling-enabled queue
    if not queue_properties & profiling_flag:
        raise RuntimeError("Profiling was not enabled in the command queue. "
             "Please create the queue with "
             "cl.command_queue_properties.PROFILING_ENABLE.")

    self.profile_kernels = True
357+
358+
def _wait_and_transfer_profile_events(self) -> None:
    """Wait for all pending profiling events to finish and transfer their
    timings to *self._profile_results*.

    Each pending event contributes one entry, appended to the list stored
    under its ``t_unit_name``. Afterwards the list of pending events is
    empty. This is a no-op when no events are pending.
    """
    pending = self._profile_events
    if not pending:
        # Nothing was recorded: skip the pyopencl import and leave the
        # results untouched.
        return

    import pyopencl as cl

    # First, wait for completion of all events
    cl.wait_for_events([p_event.stop_cl_event for p_event in pending])

    # Then, collect all events and store them
    for p_event in pending:
        # The interval runs from the *end* of the start event to the end of
        # the stop event.
        elapsed = (p_event.stop_cl_event.profile.end
                   - p_event.start_cl_event.profile.end)

        self._profile_results.setdefault(
            p_event.t_unit_name, []).append(elapsed)

    self._profile_events = []
376+
377+
def _add_profiling_events(self, start: cl._cl.Event, stop: cl._cl.Event,
                          t_unit_name: str) -> None:
    """Record a *start*/*stop* event pair as a pending profiling event.

    :arg start: event marking the beginning of the measured interval.
    :arg stop: event marking the end of the measured interval.
    :arg t_unit_name: name under which the timing will be reported.
    """
    pending_event = ProfileEvent(
        start_cl_event=start,
        stop_cl_event=stop,
        t_unit_name=t_unit_name)
    self._profile_events.append(pending_event)
381+
382+
def _reset_profiling_data(self) -> None:
    """Discard all collected profiling results.

    Only *self._profile_results* is replaced (by a fresh, empty mapping);
    events that are still pending are not touched.
    """
    self._profile_results = dict()
385+
386+
# }}}
387+
325388
@property
326389
def _frozen_array_types(self) -> tuple[type, ...]:
327390
import pyopencl.array as cla
@@ -546,9 +609,18 @@ def _to_frozen(key: tuple[Any, ...], ary) -> TaggableCLArray:
546609
self._dag_transform_cache[normalized_expr])
547610

548611
assert len(pt_prg.bound_arguments) == 0
549-
_evt, out_dict = pt_prg(self.queue,
612+
613+
if self.profile_kernels:
614+
import pyopencl as cl
615+
start_evt = cl.enqueue_marker(self.queue)
616+
617+
evt, out_dict = pt_prg(self.queue,
550618
allocator=self.allocator,
551619
**bound_arguments)
620+
621+
if self.profile_kernels:
622+
self._add_profiling_events(start_evt, evt, pt_prg.program.entrypoint)
623+
552624
assert len(set(out_dict) & set(key_to_frozen_subary)) == 0
553625

554626
key_to_frozen_subary = {

arraycontext/impl/pytato/compile.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -636,10 +636,17 @@ def __call__(self, arg_id_to_arg) -> ArrayContainer:
636636
input_kwargs_for_loopy = _args_to_device_buffers(
637637
self.actx, self.input_id_to_name_in_program, arg_id_to_arg, fn_name)
638638

639-
_evt, out_dict = self.pytato_program(queue=self.actx.queue,
639+
if self.actx.profile_kernels:
640+
import pyopencl as cl
641+
start_evt = cl.enqueue_marker(self.actx.queue)
642+
643+
evt, out_dict = self.pytato_program(queue=self.actx.queue,
640644
allocator=self.actx.allocator,
641645
**input_kwargs_for_loopy)
642646

647+
if self.actx.profile_kernels:
648+
self.actx._add_profiling_events(start_evt, evt, fn_name)
649+
643650
def to_output_template(keys, _):
644651
name_in_program = self.output_id_to_name_in_program[keys]
645652
return self.actx.thaw(to_tagged_cl_array(
@@ -675,10 +682,17 @@ def __call__(self, arg_id_to_arg) -> ArrayContainer:
675682
input_kwargs_for_loopy = _args_to_device_buffers(
676683
self.actx, self.input_id_to_name_in_program, arg_id_to_arg, fn_name)
677684

678-
_evt, out_dict = self.pytato_program(queue=self.actx.queue,
685+
if self.actx.profile_kernels:
686+
import pyopencl as cl
687+
start_evt = cl.enqueue_marker(self.actx.queue)
688+
689+
evt, out_dict = self.pytato_program(queue=self.actx.queue,
679690
allocator=self.actx.allocator,
680691
**input_kwargs_for_loopy)
681692

693+
if self.actx.profile_kernels:
694+
self.actx._add_profiling_events(start_evt, evt, fn_name)
695+
682696
return self.actx.thaw(to_tagged_cl_array(out_dict[self.output_name],
683697
axes=get_cl_axes_from_pt_axes(
684698
self.output_axes),

arraycontext/impl/pytato/utils.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,12 @@
44
__doc__ = """
55
.. autofunction:: transfer_from_numpy
66
.. autofunction:: transfer_to_numpy
7+
8+
9+
Profiling-related functions
10+
^^^^^^^^^^^^^^^^^^^^^^^^^^^
11+
12+
.. autofunction:: tabulate_profiling_data
713
"""
814

915

@@ -35,6 +41,7 @@
3541
from collections.abc import Mapping
3642
from typing import TYPE_CHECKING, Any, cast
3743

44+
import pytools
3845
from pytato.array import (
3946
AbstractResultWithNamedArrays,
4047
Array,
@@ -51,6 +58,7 @@
5158

5259
from arraycontext import ArrayContext
5360
from arraycontext.impl.pyopencl.taggable_cl_array import Axis as ClAxis
61+
from arraycontext.impl.pytato import PytatoPyOpenCLArrayContext
5462

5563

5664
if TYPE_CHECKING:
@@ -221,4 +229,42 @@ def transfer_to_numpy(expr: ArrayOrNames, actx: ArrayContext) -> ArrayOrNames:
221229

222230
# }}}
223231

232+
233+
# {{{ Profiling
234+
235+
def tabulate_profiling_data(actx: PytatoPyOpenCLArrayContext) -> pytools.Table:
    """Return a :class:`pytools.Table` with the profiling results.

    Pending profiling events of *actx* are waited for and collected first.
    The table holds a header row, one row per kernel name (number of calls,
    summed time and average time), an empty separator row and a final
    ``Total`` row.

    .. note::

        The profiling results stored in *actx* are reset by this call, so
        each invocation only reports what was recorded since the previous
        one.
    """
    actx._wait_and_transfer_profile_events()

    tbl = pytools.Table()

    # Table header
    tbl.add_row(("Kernel", "# Calls", "Time_sum [ns]", "Time_avg [ns]"))

    # Precision of results
    g = ".5g"

    total_calls = 0
    total_time = 0.0

    for kernel_name, times in actx._profile_results.items():
        num_calls = len(times)
        t_sum = sum(times)
        t_avg = t_sum / num_calls

        total_calls += num_calls
        # sum() always yields a number, so the total is accumulated
        # unconditionally.
        total_time += t_sum

        tbl.add_row((kernel_name, num_calls, f"{t_sum:{g}}", f"{t_avg:{g}}"))

    tbl.add_row(("", "", "", ""))
    tbl.add_row(("Total", total_calls, f"{total_time:{g}}", "--"))

    actx._reset_profiling_data()

    return tbl
267+
268+
# }}}
269+
224270
# vim: foldmethod=marker

test/test_pytato_arraycontext.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727
import logging
2828

29+
import numpy as np
2930
import pytest
3031

3132
from pytools.tag import Tag
@@ -274,6 +275,99 @@ def twice(x, y, a):
274275
assert isinstance(ep.arg_dict["_actx_in_2"], lp.ArrayArg)
275276

276277

278+
def test_profiling_actx():
    """Check kernel profiling for compiled and frozen evaluation, toggling
    profiling off and on again, and the error raised for a command queue
    created without profiling support."""
    import pyopencl as cl

    cl_ctx = cl.create_some_context()
    queue = cl.CommandQueue(
        cl_ctx, properties=cl.command_queue_properties.PROFILING_ENABLE)

    actx = PytatoPyOpenCLArrayContext(queue, profile_kernels=True)

    ncalls = 10

    def twice(x):
        return 2 * x

    def call_compiled():
        # every call launches the compiled program once
        for _ in range(ncalls):
            assert actx.to_numpy(f(99)) == 198

    def call_frozen():
        # every to_numpy freezes (and thereby evaluates) a new expression
        for _ in range(ncalls):
            ary = actx.from_numpy(np.array([99, 99]))
            result = actx.to_numpy(twice(ary))
            assert np.all(result == 198)

    # {{{ Compiled test

    f = actx.compile(twice)

    assert not actx._profile_events

    call_compiled()

    assert len(actx._profile_events) == ncalls
    actx._wait_and_transfer_profile_events()
    assert not actx._profile_events
    assert len(actx._profile_results) == 1
    assert len(actx._profile_results["twice"]) == ncalls

    from arraycontext.impl.pytato.utils import tabulate_profiling_data

    # tabulating also resets the collected results
    print(tabulate_profiling_data(actx))
    assert not actx._profile_results

    # }}}

    # {{{ Uncompiled/frozen test

    assert not actx._profile_events

    call_frozen()

    assert len(actx._profile_events) == ncalls
    actx._wait_and_transfer_profile_events()
    assert not actx._profile_events
    assert len(actx._profile_results) == 1
    assert len(actx._profile_results["frozen_result"]) == ncalls

    print(tabulate_profiling_data(actx))

    assert not actx._profile_results

    # }}}

    # {{{ test disabling profiling

    actx._enable_profiling(False)

    assert not actx._profile_events

    call_compiled()

    assert not actx._profile_events
    assert not actx._profile_results

    # }}}

    # {{{ test enabling profiling

    actx._enable_profiling(True)

    assert not actx._profile_events

    call_compiled()

    assert len(actx._profile_events) == ncalls
    actx._wait_and_transfer_profile_events()
    assert not actx._profile_events
    assert len(actx._profile_results) == 1

    # }}}

    # A queue without PROFILING_ENABLE must be rejected, both at construction
    # time and when profiling is switched on later.
    queue2 = cl.CommandQueue(cl_ctx)

    with pytest.raises(RuntimeError):
        PytatoPyOpenCLArrayContext(queue2, profile_kernels=True)

    actx2 = PytatoPyOpenCLArrayContext(queue2)

    with pytest.raises(RuntimeError):
        actx2._enable_profiling(True)
369+
370+
277371
if __name__ == "__main__":
278372
import sys
279373
if len(sys.argv) > 1:

0 commit comments

Comments
 (0)