Skip to content

Commit f977ba6

Browse files
TimPansinolrafeei
andauthored
Add TraceCache Guarded Iteration (#704)
* Add MutableMapping API to TraceCache * Update trace cache usage to use guarded APIs. * [Mega-Linter] Apply linters fixes * Bump tests * Fix keys iterator * Comments for trace cache methods * Reorganize tests * Fix fixture refs * Fix testing refs * [Mega-Linter] Apply linters fixes * Bump tests * Upper case constant Co-authored-by: TimPansino <[email protected]> Co-authored-by: Lalleh Rafeei <[email protected]>
1 parent a63e33f commit f977ba6

File tree

6 files changed

+228
-39
lines changed

6 files changed

+228
-39
lines changed

newrelic/core/context.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def log_propagation_failure(s):
4646
elif trace is not None:
4747
self.trace = trace
4848
elif trace_cache_id is not None:
49-
self.trace = self.trace_cache._cache.get(trace_cache_id, None)
49+
self.trace = self.trace_cache.get(trace_cache_id, None)
5050
if self.trace is None:
5151
log_propagation_failure("No trace with id %d." % trace_cache_id)
5252
elif hasattr(request, "_nr_trace") and request._nr_trace is not None:
@@ -60,22 +60,22 @@ def __enter__(self):
6060
self.thread_id = self.trace_cache.current_thread_id()
6161

6262
# Save previous cache contents
63-
self.restore = self.trace_cache._cache.get(self.thread_id, None)
63+
self.restore = self.trace_cache.get(self.thread_id, None)
6464
self.should_restore = True
6565

6666
# Set context in trace cache
67-
self.trace_cache._cache[self.thread_id] = self.trace
67+
self.trace_cache[self.thread_id] = self.trace
6868

6969
return self
7070

7171
def __exit__(self, exc, value, tb):
7272
if self.should_restore:
7373
if self.restore is not None:
7474
# Restore previous contents
75-
self.trace_cache._cache[self.thread_id] = self.restore
75+
self.trace_cache[self.thread_id] = self.restore
7676
else:
7777
# Remove entry from cache
78-
self.trace_cache._cache.pop(self.thread_id)
78+
self.trace_cache.pop(self.thread_id)
7979

8080

8181
def context_wrapper(func, trace=None, request=None, trace_cache_id=None, strict=True):

newrelic/core/trace_cache.py

Lines changed: 82 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,11 @@
2828
except ImportError:
2929
import _thread as thread
3030

31+
try:
32+
from collections.abc import MutableMapping
33+
except ImportError:
34+
from collections import MutableMapping
35+
3136
from newrelic.core.config import global_settings
3237
from newrelic.core.loop_node import LoopNode
3338

@@ -92,15 +97,15 @@ class TraceCacheActiveTraceError(RuntimeError):
9297
pass
9398

9499

95-
class TraceCache(object):
100+
class TraceCache(MutableMapping):
96101
asyncio = cached_module("asyncio")
97102
greenlet = cached_module("greenlet")
98103

99104
def __init__(self):
100105
self._cache = weakref.WeakValueDictionary()
101106

102107
def __repr__(self):
103-
return "<%s object at 0x%x %s>" % (self.__class__.__name__, id(self), str(dict(self._cache.items())))
108+
return "<%s object at 0x%x %s>" % (self.__class__.__name__, id(self), str(dict(self.items())))
104109

105110
def current_thread_id(self):
106111
"""Returns the thread ID for the caller.
@@ -135,22 +140,22 @@ def current_thread_id(self):
135140
def task_start(self, task):
136141
trace = self.current_trace()
137142
if trace:
138-
self._cache[id(task)] = trace
143+
self[id(task)] = trace
139144

140145
def task_stop(self, task):
141-
self._cache.pop(id(task), None)
146+
self.pop(id(task), None)
142147

143148
def current_transaction(self):
144149
"""Return the transaction object if one exists for the currently
145150
executing thread.
146151
147152
"""
148153

149-
trace = self._cache.get(self.current_thread_id())
154+
trace = self.get(self.current_thread_id())
150155
return trace and trace.transaction
151156

152157
def current_trace(self):
153-
return self._cache.get(self.current_thread_id())
158+
return self.get(self.current_thread_id())
154159

155160
def active_threads(self):
156161
"""Returns an iterator over all current stack frames for all
@@ -169,7 +174,7 @@ def active_threads(self):
169174
# First yield up those for real Python threads.
170175

171176
for thread_id, frame in sys._current_frames().items():
172-
trace = self._cache.get(thread_id)
177+
trace = self.get(thread_id)
173178
transaction = trace and trace.transaction
174179
if transaction is not None:
175180
if transaction.background_task:
@@ -197,7 +202,7 @@ def active_threads(self):
197202
debug = global_settings().debug
198203

199204
if debug.enable_coroutine_profiling:
200-
for thread_id, trace in list(self._cache.items()):
205+
for thread_id, trace in self.items():
201206
transaction = trace.transaction
202207
if transaction and transaction._greenlet is not None:
203208
gr = transaction._greenlet()
@@ -212,7 +217,7 @@ def prepare_for_root(self):
212217
trace in the cache is from a different task (for asyncio). Returns the
213218
current trace after the cache is updated."""
214219
thread_id = self.current_thread_id()
215-
trace = self._cache.get(thread_id)
220+
trace = self.get(thread_id)
216221
if not trace:
217222
return None
218223

@@ -221,11 +226,11 @@ def prepare_for_root(self):
221226

222227
task = current_task(self.asyncio)
223228
if task is not None and id(trace._task) != id(task):
224-
self._cache.pop(thread_id, None)
229+
self.pop(thread_id, None)
225230
return None
226231

227232
if trace.root and trace.root.exited:
228-
self._cache.pop(thread_id, None)
233+
self.pop(thread_id, None)
229234
return None
230235

231236
return trace
@@ -240,8 +245,8 @@ def save_trace(self, trace):
240245

241246
thread_id = trace.thread_id
242247

243-
if thread_id in self._cache:
244-
cache_root = self._cache[thread_id].root
248+
if thread_id in self:
249+
cache_root = self[thread_id].root
245250
if cache_root and cache_root is not trace.root and not cache_root.exited:
246251
# Cached trace exists and has a valid root still
247252
_logger.error(
@@ -253,7 +258,7 @@ def save_trace(self, trace):
253258

254259
raise TraceCacheActiveTraceError("transaction already active")
255260

256-
self._cache[thread_id] = trace
261+
self[thread_id] = trace
257262

258263
# We judge whether we are actually running in a coroutine by
259264
# seeing if the current thread ID is actually listed in the set
@@ -284,7 +289,7 @@ def pop_current(self, trace):
284289

285290
thread_id = trace.thread_id
286291
parent = trace.parent
287-
self._cache[thread_id] = parent
292+
self[thread_id] = parent
288293

289294
def complete_root(self, root):
290295
"""Completes a trace specified by the given root
@@ -301,7 +306,7 @@ def complete_root(self, root):
301306
to_complete = []
302307

303308
for task_id in task_ids:
304-
entry = self._cache.get(task_id)
309+
entry = self.get(task_id)
305310

306311
if entry and entry is not root and entry.root is root:
307312
to_complete.append(entry)
@@ -316,12 +321,12 @@ def complete_root(self, root):
316321

317322
thread_id = root.thread_id
318323

319-
if thread_id not in self._cache:
324+
if thread_id not in self:
320325
thread_id = self.current_thread_id()
321-
if thread_id not in self._cache:
326+
if thread_id not in self:
322327
raise TraceCacheNoActiveTraceError("no active trace")
323328

324-
current = self._cache.get(thread_id)
329+
current = self.get(thread_id)
325330

326331
if root is not current:
327332
_logger.error(
@@ -333,7 +338,7 @@ def complete_root(self, root):
333338

334339
raise RuntimeError("not the current trace")
335340

336-
del self._cache[thread_id]
341+
del self[thread_id]
337342
root._greenlet = None
338343

339344
def record_event_loop_wait(self, start_time, end_time):
@@ -359,7 +364,7 @@ def record_event_loop_wait(self, start_time, end_time):
359364
task = getattr(transaction.root_span, "_task", None)
360365
loop = get_event_loop(task)
361366

362-
for trace in list(self._cache.values()):
367+
for trace in self.values():
363368
if trace in seen:
364369
continue
365370

@@ -390,6 +395,62 @@ def record_event_loop_wait(self, start_time, end_time):
390395
root.increment_child_count()
391396
root.add_child(node)
392397

398+
# MutableMapping methods
399+
400+
def items(self):
401+
"""
402+
Safely iterates on self._cache.items() indirectly using a list of value references
403+
to avoid RuntimeErrors from size changes during iteration.
404+
"""
405+
for wr in self._cache.valuerefs():
406+
value = wr() # Dereferenced value is potentially no longer live.
407+
if (
408+
value is not None
409+
): # weakref is None means weakref has been garbage collected and is no longer live. Ignore.
410+
yield wr.key, value # wr.key is the original dict key
411+
412+
def keys(self):
413+
"""
414+
Iterates on self._cache.keys() indirectly using a list of value references
415+
to avoid RuntimeErrors from size changes during iteration.
416+
417+
NOTE: Returned keys are keys to weak references which may at any point be garbage collected.
418+
It is only safe to retrieve values from the trace cache using trace_cache.get(key, None).
419+
Retrieving values using trace_cache[key] can cause a KeyError if the item has been garbage collected.
420+
"""
421+
for wr in self._cache.valuerefs():
422+
yield wr.key # wr.key is the original dict key
423+
424+
def values(self):
425+
"""
426+
Safely iterates on self._cache.values() indirectly using a list of value references
427+
to avoid RuntimeErrors from size changes during iteration.
428+
"""
429+
for wr in self._cache.valuerefs():
430+
value = wr() # Dereferenced value is potentially no longer live.
431+
if (
432+
value is not None
433+
): # weakref is None means weakref has been garbage collected and is no longer live. Ignore.
434+
yield value
435+
436+
def __getitem__(self, key):
437+
return self._cache.__getitem__(key)
438+
439+
def __setitem__(self, key, value):
440+
self._cache.__setitem__(key, value)
441+
442+
def __delitem__(self, key):
443+
self._cache.__delitem__(key)
444+
445+
def __iter__(self):
446+
return self.keys()
447+
448+
def __len__(self):
449+
return self._cache.__len__()
450+
451+
def __bool__(self):
452+
return bool(self._cache.__len__())
453+
393454

394455
_trace_cache = TraceCache()
395456

tests/agent_features/test_async_context_propagation.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,11 @@
1313
# limitations under the License.
1414

1515
import pytest
16-
from testing_support.fixtures import (
17-
function_not_called,
18-
override_generic_settings,
16+
from testing_support.fixtures import function_not_called, override_generic_settings
17+
from testing_support.validators.validate_transaction_metrics import (
18+
validate_transaction_metrics,
1919
)
20-
from testing_support.validators.validate_transaction_metrics import validate_transaction_metrics
20+
2121
from newrelic.api.application import application_instance as application
2222
from newrelic.api.background_task import BackgroundTask, background_task
2323
from newrelic.api.database_trace import database_trace
@@ -131,7 +131,7 @@ def handle_exception(loop, context):
131131
# The agent should have removed all traces from the cache since
132132
# run_until_complete has terminated (all callbacks scheduled inside the
133133
# task have run)
134-
assert not trace_cache()._cache
134+
assert not trace_cache()
135135

136136
# Assert that no exceptions have occurred
137137
assert not exceptions, exceptions
@@ -286,7 +286,7 @@ def _test():
286286

287287
# The agent should have removed all traces from the cache since
288288
# run_until_complete has terminated
289-
assert not trace_cache()._cache
289+
assert not trace_cache()
290290

291291
# Assert that no exceptions have occurred
292292
assert not exceptions, exceptions

tests/agent_features/test_event_loop_wait_time.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def _test():
140140
def test_record_event_loop_wait_outside_task():
141141
# Insert a random trace into the trace cache
142142
trace = FunctionTrace(name="testing")
143-
trace_cache()._cache[0] = trace
143+
trace_cache()[0] = trace
144144

145145
@background_task(name="test_record_event_loop_wait_outside_task")
146146
def _test():

0 commit comments

Comments
 (0)