diff --git a/configs/common/Options.py b/configs/common/Options.py index 937bdecac4..441840b5c8 100644 --- a/configs/common/Options.py +++ b/configs/common/Options.py @@ -306,6 +306,11 @@ def addCommonOptions(parser, configure_xiangshan=False): help=""" Prefetching cache level for SMS'pht""") + parser.add_argument("--enable-pf-buffer", action="store_true", default=False, + help=""" + Force all hardware prefetchers to enable their + optional prefetch buffer (QueuedPrefetcher.use_pf_buffer).""") + parser.add_argument("--cpu-clock", action="store", type=str, default='3GHz', help="Clock for blocks running at CPU speed") diff --git a/configs/common/PrefetcherConfig.py b/configs/common/PrefetcherConfig.py index 3e3704ec2a..44387ba4d8 100644 --- a/configs/common/PrefetcherConfig.py +++ b/configs/common/PrefetcherConfig.py @@ -20,6 +20,10 @@ def create_prefetcher(cpu, cache_level, options): prefetcher = _get_hwp(prefetcher_name) print(f"create_prefetcher at {cache_level}: {prefetcher_name}") + if prefetcher != NULL and getattr(options, 'enable_pf_buffer', False): + if hasattr(prefetcher, 'use_pf_buffer'): + prefetcher.use_pf_buffer = True + if prefetcher == NULL: return NULL @@ -59,6 +63,10 @@ def create_prefetcher(cpu, cache_level, options): prefetcher.enable_activepage = False prefetcher.enable_pht = True prefetcher.enable_xsstream = True + prefetcher.prefetch_train = False # disable L1PF train L2 + # disable unnecessary filter to align with RTL when in pf_buffer mode + if hasattr(prefetcher, 'queue_filter'): + prefetcher.queue_filter = False if cache_level == 'l2': if options.classic_l2: @@ -72,6 +80,10 @@ def create_prefetcher(cpu, cache_level, options): prefetcher.enable_despacito_stream = False prefetcher.bop_large = XSVirtualLargeBOP(is_sub_prefetcher=True,enable_adaptoffset=False) prefetcher.bop_small = XSPhysicalSmallBOP(is_sub_prefetcher=True,enable_adaptoffset=False) + prefetcher.prefetch_train = False # disable L1PF train L2 + # disable unnecessary filter to 
align with RTL when in pf_buffer mode + if hasattr(prefetcher, 'queue_filter'): + prefetcher.queue_filter = False if options.l1_to_l2_pf_hint: prefetcher.queue_size = 64 prefetcher.max_prefetch_requests_with_pending_translation = 128 @@ -88,10 +100,17 @@ def create_prefetcher(cpu, cache_level, options): prefetcher.enable_bop = True prefetcher.enable_cdp = False prefetcher.enable_despacito_stream = False + if prefetcher.enable_despacito_stream: + # if you want to check despacito pattern trace, set this to True + prefetcher.despacito_stream.enable_despacito_db = False prefetcher.bop_large = XSVirtualLargeBOP(is_sub_prefetcher=True,enable_adaptoffset=False) prefetcher.bop_small = XSPhysicalSmallBOP(is_sub_prefetcher=True,enable_adaptoffset=False) + prefetcher.prefetch_train = False # disable L1PF train L2 + # disable unnecessary filter to align with RTL when in pf_buffer mode + if hasattr(prefetcher, 'queue_filter'): + prefetcher.queue_filter = False if options.l1_to_l2_pf_hint: - prefetcher.queue_size = 64 + prefetcher.queue_size = 32 prefetcher.max_prefetch_requests_with_pending_translation = 128 if cache_level == 'l3': diff --git a/configs/common/xiangshan.py b/configs/common/xiangshan.py index ed78ebc922..4acf8ff4a1 100644 --- a/configs/common/xiangshan.py +++ b/configs/common/xiangshan.py @@ -708,6 +708,7 @@ def _finish_xiangshan_system(args, test_sys, TestCPUClass, ruby): test_sys.arch_db.dump_l3_evict_trace = False test_sys.arch_db.dump_l1_miss_trace = False test_sys.arch_db.dump_bop_train_trace = False + test_sys.arch_db.dump_stride_train_trace = False test_sys.arch_db.dump_sms_train_trace = False test_sys.arch_db.dump_vaddr_trace = False test_sys.arch_db.dump_lifetime = False @@ -780,6 +781,29 @@ def _finish_xiangshan_system(args, test_sys, TestCPUClass, ruby): "Conf INT NOT NULL," \ "Miss BOOL NOT NULL," \ "SITE TEXT);" + , + "CREATE TABLE StrideTrainTrace(" \ + "ID INTEGER PRIMARY KEY AUTOINCREMENT," \ + "Tick INT NOT NULL," \ + "Addr INT NOT NULL," \ + "PC 
INT NOT NULL," \ + "HashPC INT NOT NULL," \ + "QueryHit BOOL NOT NULL," \ + "IsFirstShot BOOL NOT NULL," \ + "Miss BOOL NOT NULL," \ + "IsTrain BOOL NOT NULL," \ + "SITE TEXT);" + , + "CREATE TABLE DespacitoTrainTrace(" \ + "ID INTEGER PRIMARY KEY AUTOINCREMENT," \ + "Tick INT NOT NULL," \ + "vAddr INT NOT NULL," \ + "pAddr INT NOT NULL," \ + "PC INT NOT NULL," \ + "hasPC BOOL NOT NULL," \ + "Miss BOOL NOT NULL," \ + "IsTrain BOOL NOT NULL," \ + "SITE TEXT);" ,# perfCounter CommitTrace perfCCT_cmd ] diff --git a/configs/example/idealkmhv3.py b/configs/example/idealkmhv3.py index 76d6d4800f..68f2c9b625 100644 --- a/configs/example/idealkmhv3.py +++ b/configs/example/idealkmhv3.py @@ -138,7 +138,8 @@ def setKmhV3IdealParams(args, system): # If user didn't specify bp_type, set default based on ideal_kmhv3 args.bp_type = 'DecoupledBPUWithBTB' args.l2_size = '2MB' - + # Disable prefetch buffers for all hardware prefetchers in this config. + args.enable_pf_buffer = False # Match the memories with the CPUs, based on the options for the test system TestMemClass = Simulation.setMemClass(args) diff --git a/configs/example/kmhv2.py b/configs/example/kmhv2.py index 206203f2ea..e3407480e6 100644 --- a/configs/example/kmhv2.py +++ b/configs/example/kmhv2.py @@ -30,7 +30,8 @@ # disable l1 berti, l2 cdp args.l2_wrapper_hwp_type = "L2CompositeWithWorkerPrefetcher" args.kmh_align = True - + # Enable prefetch buffers for all hardware prefetchers in this config. + args.enable_pf_buffer = True assert not args.external_memory_system test_mem_mode = 'timing' diff --git a/configs/example/kmhv3.py b/configs/example/kmhv3.py index d82c6b98bd..f5b986bf17 100644 --- a/configs/example/kmhv3.py +++ b/configs/example/kmhv3.py @@ -182,6 +182,9 @@ def setKmhV3Params(args, system): assert not args.external_memory_system + # Enable prefetch buffers for all hardware prefetchers in this config. 
+ args.enable_pf_buffer = True + # Set default bp_type based on ideal_kmhv3 flag # If user didn't specify bp_type, set default based on ideal_kmhv3 args.bp_type = 'DecoupledBPUWithBTB' diff --git a/docs/prefetch_cache_partition_plan.md b/docs/prefetch_cache_partition_plan.md new file mode 100644 index 0000000000..0da9fecd89 --- /dev/null +++ b/docs/prefetch_cache_partition_plan.md @@ -0,0 +1,82 @@ +# 预取专用Cache分区计划 + +## 目标 +为 Cache 增加一个只存放预取行的分区。预取的插入/替换/查询仅在该分区内进行,普通需求行不得占用。需求命中应能正常读取,是否迁移需定义。 + +## 现有关键位置 +- 访问/命中/未命中与填充逻辑位于 [src/mem/cache/base.cc](src/mem/cache/base.cc)。 +- Cache 前端类在 [src/mem/cache/cache.hh](src/mem/cache/cache.hh) 及实现 [src/mem/cache/cache.cc](src/mem/cache/cache.cc)。 +- 标签与组相联实现位于 [src/mem/cache/tags/base.hh](src/mem/cache/tags/base.hh)、[src/mem/cache/tags/base_set_assoc.hh](src/mem/cache/tags/base_set_assoc.hh)、[src/mem/cache/tags/base_set_assoc.cc](src/mem/cache/tags/base_set_assoc.cc)。 +- 预取元数据在填充时写入,逻辑在 base.cc 的填充路径。 + +## 待确认设计点 +1) 分区形态:每组预留若干路 vs 独立标签存储,倾向每组预留路以便接入现有索引。 +2) 需求命中是否迁移到主分区,或仅标记“已需求访问”并留在预取分区。 +3) 替换策略:预取分区独立策略实例,或复用同一策略但作用于子集路。 +4) 容量参数:按路数/字节/比例暴露,支持 0 关闭。 +5) 预取分区驱逐与主分区在一致性/写回语义上是否完全相同。 + +## 工作步骤 +1) **参数面**:新增 Cache 参数(开关与分区大小),配置层与 SCons 暴露。 +2) **块元数据**:扩展 `CacheBlk` 标记分区驻留与是否被需求提升,覆盖序列化/重置。 +3) **标签/组相联分区**:在 BaseSetAssoc(或派生)维护主/预取两套替换池;查找覆盖两分区;按分区选牺牲块。 +4) **查找路径**:`BaseCache::access` 同查两分区;命中预取分区时的处理与统计拆分。 +5) **填充/分配路由**:`handleFill`/`allocateBlock` 将预取响应路由至预取分区,禁止非预取分配进入;分区满时仅内部驱逐。 +6) **提升策略**:需求命中预取分区时决定“迁移到主分区”或“原地标记需求”,并更新替换元数据。 +7) **驱逐/写回路径**:`evictBlock`、`doWritebacks`、CleanEvict 保持分区隔离但语义一致,统计独立。 +8) **预取队列**:确保 MSHR 分配对预取请求打标,便于填充分路由识别;必要时增加断言/计数。 +9) **统计/探针**:预取分区占用、命中、填充、驱逐、提升、DOA 等计数,必要时调整暖身处理。 +10) **配置与文档**:更新 Cache.py/PrefetcherConfig.py 等配置,补充用户文档与调优说明。 +11) **测试**:微基准覆盖仅预取、混合需求+预取、启用分区的 checkpoint/恢复、一致性与失效场景。 + +## 风险/关注点 +- 替换策略需支持双池。 +- 分区化可能影响延迟建模(tag/data 端口计数)。 +- 序列化需保留分区与替换状态。 +- 默认关闭以确保兼容性。 + +## 讨论更新(方案锁定后续步骤) +基于已确认的设计选项: +1) 分区形态:每组预留若干路。 +2) 需求命中仅标记“已需求访问”并留在预取分区,不迁移。 +3) 
替换策略:复用同一替换策略实例,但作用于分区内子集路。 +4) 容量参数:按比例暴露,支持 0 关闭。 +5) 驱逐语义:预取分区与主分区在一致性/写回语义上完全相同。 + +### 下一步工作规划 +1) **参数落地**:新增比例型分区参数与开关;Python 配置与 SCons 接口同步;默认 0 关闭。 +2) **块元数据**:为 `CacheBlk` 增加分区驻留标记与“已需求访问”标志;补齐序列化/重置逻辑。 +3) **分区化标签/替换**:在 BaseSetAssoc(或派生)实现按组预留路的双分区池,复用同一替换策略实例但限制候选路集;提供按分区选 victim 的接口。 +4) **访问路径**:`BaseCache::access` 查两分区;命中预取分区时仅标记“已需求访问”,不迁移;统计分拆。 +5) **填充/分配路由**:`handleFill`/`allocateBlock` 将预取响应放入预取分区;普通需求不得占用;分区满时仅内部驱逐,语义与主分区一致。 +6) **驱逐与写回**:复用现有写回/一致性流程,但保持分区隔离与统计独立;确保 CleanEvict/Writeback 行为一致。 +7) **预取标记链路**:确保 MSHR/请求链路对预取打标,便于填充分区路由;缺标时记录计数/告警。 +8) **统计**:新增占用、命中、填充、驱逐、已需求访问计数(含 DOA),按分区拆分;若使用比例参数,记录实际预留路数。 +9) **测试计划**: + - 仅预取流:验证预取分区被使用且需求不会进入。 + - 混合流:需求命中预取分区仅标记不迁移;容量隔离有效。 + - 恢复/一致性:checkpoint/restore,CleanEvict/Writeback 语义一致性。 + - 关闭模式:参数为 0 时行为与原先一致。 + + ## 讨论更新(追加) + 新的设计点修正: + 1) 分区形态:每组预留若干路,按比例参数计算预留路数;向下取整;若不足 1 路则该组预留 0(等价于禁用)。 + 2) 需求命中:仅标记“已需求访问”,不迁移出预取分区。 + 3) 替换策略:预取分区使用独立的替换策略实例(与主分区解耦)。 + 4) 容量参数:按比例暴露,支持 0 关闭;组数若非幂次则报错(仅在参数为幂次组数时工作)。 + 5) 驱逐语义:预取分区与主分区在一致性/写回上完全一致。 + + ### 下一步工作规划(根据新设计点) + 1) **参数实现**:比例型分区参数,向下截断到整路;不足 1 路取 0;若组数非幂次直接报错;默认 0 关闭。配置层与 SCons 同步。 + 2) **块元数据**:`CacheBlk` 增加分区驻留与“已需求访问”标志,含序列化/重置。 + 3) **分区化标签/替换**:BaseSetAssoc(或派生)维护主/预取双池,预取池使用独立替换策略实例;按组预留路数由比例计算;victim 选择限定在各自池内。 + 4) **访问路径**:`BaseCache::access` 查两分区;命中预取分区仅标记已需求访问,不迁移;统计分拆。 + 5) **填充/分配路由**:`handleFill`/`allocateBlock` 将预取响应放入预取分区,普通需求不得占用;预取分区满时仅内部驱逐,语义与主分区一致。 + 6) **驱逐与写回**:复用现有流程,保持分区隔离与统计独立,保证 CleanEvict/Writeback 语义一致。 + 7) **预取标记链路**:MSHR/请求需带预取标记,缺标计数/告警,确保填充分路由正确。 + 8) **统计**:分区占用、命中、填充、驱逐、已需求访问、DOA;记录比例参数实际转换的预留路数。 + 9) **测试**: + - 仅预取流:验证预取分区启用与容量隔离。 + - 混合流:需求命中不迁移,统计正确。 + - 非幂次组数:应触发报错覆盖。 + - 关闭模式:参数为 0 时回归旧行为。 diff --git a/src/cpu/o3/dyn_inst.cc b/src/cpu/o3/dyn_inst.cc index 2b572ee556..51efc2b8aa 100644 --- a/src/cpu/o3/dyn_inst.cc +++ b/src/cpu/o3/dyn_inst.cc @@ -63,7 +63,7 @@ namespace o3 DynInst::DynInst(const Arrays &arrays, const StaticInstPtr &static_inst, const StaticInstPtr &_macroop, InstSeqNum seq_num, CPU 
*_cpu) : seqNum(seq_num), staticInst(static_inst), - xsMeta(new XsDynInstMeta()), + xsMeta(new XsDynInstMeta(seq_num)), cpu(_cpu), _numSrcs(arrays.numSrcs), _numDests(arrays.numDests), _flatDestIdx(arrays.flatDestIdx), _destIdx(arrays.destIdx), diff --git a/src/cpu/o3/dyn_inst_xsmeta.hh b/src/cpu/o3/dyn_inst_xsmeta.hh index 60c1b2dbdf..9ee50a3e1a 100644 --- a/src/cpu/o3/dyn_inst_xsmeta.hh +++ b/src/cpu/o3/dyn_inst_xsmeta.hh @@ -59,12 +59,14 @@ namespace o3 class XsDynInstMeta : public RefCounted { -public: - bool squashed; - Addr instAddr; + public: + bool squashed; + Addr instAddr; + InstSeqNum seqNum; -public: - XsDynInstMeta(): squashed(false),instAddr(0) {} + public: + XsDynInstMeta(): squashed(false), instAddr(0), seqNum(0) {} + XsDynInstMeta(InstSeqNum seq): squashed(false), instAddr(0), seqNum(seq) {} }; using XsDynInstMetaPtr = RefCountingPtr; diff --git a/src/mem/cache/base.cc b/src/mem/cache/base.cc index 74fa8a698a..3d14d1cf76 100644 --- a/src/mem/cache/base.cc +++ b/src/mem/cache/base.cc @@ -514,6 +514,13 @@ BaseCache::handleTimingReqMiss(PacketPtr pkt, MSHR *mshr, CacheBlk *blk, pkt->pfSource = mshr->getPFSource(); pkt->pfDepth = mshr->getPFDepth(); + // Demand request merging into prefetch-only MSHR + if (pkt->isDemand()) { + stats.demandMergedIntoPfMSHR++; + DPRINTF(Cache, "Demand request %#lx merged into prefetch MSHR\n", + pkt->getAddr()); + } + } else if (mshr->hasFromCPU()) { // no pkt in mshr originated from cache; all of them are from cpu pkt->coalescingMSHR = true; @@ -2887,6 +2894,12 @@ BaseCache::CacheStats::CacheStats(BaseCache &c) "number of squashed dead block replacements"), ADD_STAT(squashedLiveBlockReplacements, statistics::units::Count::get(), "number of squashed live block replacements"), + ADD_STAT(pfMergedWithDemand, statistics::units::Count::get(), + "number of MSHR completions where prefetch was merged with demand"), + ADD_STAT(pfOnlyFill, statistics::units::Count::get(), + "number of MSHR completions with only prefetch (no 
demand merge)"), + ADD_STAT(demandMergedIntoPfMSHR, statistics::units::Count::get(), + "number of demand requests that merged into prefetch MSHR"), ADD_STAT(squashedDemandHits, statistics::units::Count::get(), "number of squashed inst block demand hits"), ADD_STAT(loadTagReadFails, statistics::units::Count::get(), diff --git a/src/mem/cache/base.hh b/src/mem/cache/base.hh index 41cb5f4f52..7639d9d247 100644 --- a/src/mem/cache/base.hh +++ b/src/mem/cache/base.hh @@ -1313,6 +1313,13 @@ class BaseCache : public ClockedObject, public CacheAccessor /** Number of replacements of blocks from squashed inst but reused. */ statistics::Scalar squashedLiveBlockReplacements; + /** Number of MSHR completions where prefetch was merged with demand */ + statistics::Scalar pfMergedWithDemand; + /** Number of MSHR completions with only prefetch (no demand merge) */ + statistics::Scalar pfOnlyFill; + /** Number of demand requests that merged into prefetch MSHR */ + statistics::Scalar demandMergedIntoPfMSHR; + /** Number of demand hits that accessed squashed inst blocks. 
*/ statistics::Scalar squashedDemandHits; diff --git a/src/mem/cache/cache.cc b/src/mem/cache/cache.cc index 410fbdaeba..f59d0731df 100644 --- a/src/mem/cache/cache.cc +++ b/src/mem/cache/cache.cc @@ -988,6 +988,12 @@ Cache::serviceMSHRTargets(MSHR *mshr, const PacketPtr pkt, CacheBlk *blk) blk->setPrefetched(); blk->setXsMetadata(pkt->req->getXsMetadata()); DPRINTF(Cache, "Marking block as prefetched from prefetcher %i\n", blk->getXsMetadata().prefetchSource); + stats.pfOnlyFill++; // Pure prefetch fill (no demand merge) + } else if (blk && from_core && from_pref) { + // Prefetch was merged with demand - won't be marked as prefetched + stats.pfMergedWithDemand++; + DPRINTF(Cache, "Prefetch merged with demand for %#lx - not marking as prefetched\n", + blk->getTag()); } if (!mshr->hasLockedRMWReadTarget()) { diff --git a/src/mem/cache/prefetch/Prefetcher.py b/src/mem/cache/prefetch/Prefetcher.py index e361f93d5a..eea10b7229 100644 --- a/src/mem/cache/prefetch/Prefetcher.py +++ b/src/mem/cache/prefetch/Prefetcher.py @@ -77,6 +77,7 @@ class BasePrefetcher(ClockedObject): on_write = Param.Bool(True, "Notify prefetcher on writes") on_data = Param.Bool(True, "Notify prefetcher on data accesses") on_inst = Param.Bool(True, "Notify prefetcher on instruction accesses") + prefetch_train = Param.Bool(True, "Allow upstream PF req train low level Prefetcher") prefetch_on_access = Param.Bool(False, "Notify the hardware prefetcher on every access (not just misses)") prefetch_on_pf_hit = Param.Bool(False, @@ -88,6 +89,8 @@ class BasePrefetcher(ClockedObject): is_sub_prefetcher = Param.Bool(False, "Is this a sub-prefetcher") + training_buffer_size = Param.Unsigned(8, + "Maximum number of training requests buffered per cycle") def __init__(self, **kwargs): super().__init__(**kwargs) @@ -186,6 +189,8 @@ class QueuedPrefetcher(BasePrefetcher): that can be throttled depending on the accuracy of the prefetcher.") max_pfahead_recv = Param.Int(1,"Maximum number of pfahead received") + 
use_pf_buffer = Param.Bool(False, "use prefetch buffer to filter prefetches") + max_pf_buffer_size = Param.Int(16, "size of prefetch buffer") class XSStridePrefetcher(QueuedPrefetcher): @@ -199,23 +204,36 @@ class XSStridePrefetcher(QueuedPrefetcher): on_write = False on_data = True on_inst = False + region_size = Param.Int(1024, "region size") use_xs_depth = Param.Bool(True,"use xs rtl stride depth") fuzzy_stride_matching = Param.Bool(False, "Match stride with fuzzy condition") short_stride_thres = Param.Unsigned(512, "Ignore short strides when there are long strides (Bytes)") stride_dyn_depth = Param.Bool(False, "Dynamic depth of stride table") stride_entries = Param.MemorySize("10", "Stride Entries") - stride_indexing_policy = Param.BaseIndexingPolicy( + stride_unique_indexing_policy = Param.BaseIndexingPolicy( SetAssociative( entry_size=1, assoc=Parent.stride_entries, size=Parent.stride_entries), "Indexing policy of stride table" ) - stride_replacement_policy = Param.BaseReplacementPolicy( + stride_unique_replacement_policy = Param.BaseReplacementPolicy( + TreePLRURP(num_leaves=Parent.stride_entries), + "Replacement policy of stride table" + ) + stride_redundant_indexing_policy = Param.BaseIndexingPolicy( + SetAssociative( + entry_size=1, + assoc=Parent.stride_entries, + size=Parent.stride_entries), + "Indexing policy of stride table" + ) + stride_redundant_replacement_policy = Param.BaseReplacementPolicy( TreePLRURP(num_leaves=Parent.stride_entries), "Replacement policy of stride table" ) + use_redundant_table = Param.Bool(False, "Use redundant stride table") fuzzy_stride_matching = Param.Bool(False, "Match stride with fuzzy condition") # stride black list @@ -274,7 +292,7 @@ class XsStreamPrefetcher(QueuedPrefetcher): type = "XsStreamPrefetcher" cxx_class = "gem5::prefetch::XsStreamPrefetcher" cxx_header = "mem/cache/prefetch/xs_stream.hh" - + region_size = Param.Int(1024, "region size") use_virtual_addresses = True prefetch_on_pf_hit = True on_read = True @@ 
-296,7 +314,7 @@ class XsStreamPrefetcher(QueuedPrefetcher): "Indexing policy of active generation table" ) xs_stream_replacement_policy = Param.BaseReplacementPolicy( - LRURP(), + TreePLRURP(num_leaves = Parent.xs_stream_entries), "Replacement policy of active generation table" ) @@ -768,11 +786,23 @@ class XSVirtualLargeBOP(BOPPrefetcher): delay_queue_size = 16 delay_queue_cycles = 300 - offsets = [x for i in [ - 1, 2, 3, 4, 5, 6, 8, 9, 10, 12, 15, 16, 18, 20, 24, 25, 27, 30, 32, 36, 40, 45, 48, - 50, 54, 60, 64, 72, 75, 80, 81, 90, 96, 100, 108, 120, 125, 128, 135, 144, 150, 160, 162, 180, 192, 200, 216, - 225, 240, 243, 250 - ] for x in (i, -i)] + [-256] + offsets = [ + -117, -147, -91, 117, 147, 91, + -256, -250, -243, -240, -225, -216, -200, + -192, -180, -162, -160, -150, -144, -135, -128, + -125, -120, -108, -100, -96, -90, -81, -80, + -75, -72, -64, -60, -54, -50, -48, -45, + -40, -36, -32, -30, -27, -25, -24, -20, + -18, -16, -15, -12, -10, -9, -8, -6, + -5, -4, -3, -2, -1, + 1, 2, 3, 4, 5, 6, 8, + 9, 10, 12, 15, 16, 18, 20, 24, + 25, 27, 30, 32, 36, 40, 45, 48, + 50, 54, 60, 64, 72, 75, 80, 81, + 90, 96, 100, 108, 120, 125, 128, 135, + 144, 150, 160, 162, 180, 192, 200, 216, + 225, 240, 243, 250 + ] class SmallBOPPrefetcher(BOPPrefetcher): score_max = 31 @@ -1019,6 +1049,11 @@ class XSCompositePrefetcher(QueuedPrefetcher): on_inst = False region_size = Param.Int(1024, "region size") + + # TrainFilter configuration + enable_train_filter = Param.Bool(True, "Enable TrainFilter for ROB-order training") + training_buffer_size = 8 + # filter table (full-assoc) filter_entries = Param.MemorySize("16", "num of filter table entries") filter_indexing_policy = Param.BaseIndexingPolicy( @@ -1059,6 +1094,57 @@ class XSCompositePrefetcher(QueuedPrefetcher): size=Parent.re_act_entries), "Indexing policy of recently active generation table" ) + sms_filter_entries = Param.MemorySize( + "16", + "num of pattern history table entries" + ) + sms_filter_assoc = Param.Int(16, 
"Associativity of the pattern history table") + sms_filter_indexing_policy = Param.BaseIndexingPolicy( + SetAssociative( + entry_size=1, + assoc=Parent.sms_filter_assoc, + size=Parent.sms_filter_entries), + "Indexing policy of filter table" + ) + sms_filter_replacement_policy = Param.BaseReplacementPolicy( + TreePLRURP(num_leaves=Parent.sms_filter_entries), + "Replacement policy of filter table" + ) + + stridestream_L1_filter_entries = Param.MemorySize( + "16", + "num of pattern history table entries" + ) + stridestream_L1_filter_assoc = Param.Int(16, "Associativity of the pattern history table") + stridestream_L1_filter_indexing_policy = Param.BaseIndexingPolicy( + SetAssociative( + entry_size=1, + assoc=Parent.stridestream_L1_filter_assoc, + size=Parent.stridestream_L1_filter_entries), + "Indexing policy of filter table" + ) + stridestream_L1_filter_replacement_policy = Param.BaseReplacementPolicy( + TreePLRURP(num_leaves=Parent.stridestream_L1_filter_entries), + "Replacement policy of filter table" + ) + + stridestream_L2L3_filter_entries = Param.MemorySize( + "16", + "num of pattern history table entries" + ) + stridestream_L2L3_filter_assoc = Param.Int(16, "Associativity of the pattern history table") + stridestream_L2L3_filter_indexing_policy = Param.BaseIndexingPolicy( + SetAssociative( + entry_size=1, + assoc=Parent.stridestream_L2L3_filter_assoc, + size=Parent.stridestream_L2L3_filter_entries), + "Indexing policy of filter table" + ) + stridestream_L2L3_filter_replacement_policy = Param.BaseReplacementPolicy( + TreePLRURP(num_leaves=Parent.stridestream_L2L3_filter_entries), + "Replacement policy of filter table" + ) + vaddr_hash_width = Param.Int(5, "Width of virtual address hash") re_act_replacement_policy = Param.BaseReplacementPolicy( FIFORP(), "Replacement policy of recently active generation table" @@ -1110,6 +1196,7 @@ class XSCompositePrefetcher(QueuedPrefetcher): "Indexing policy of pattern history table" ) pht_replacement_policy = 
Param.BaseReplacementPolicy( + # TreePLRURP(num_leaves=Parent.pht_entries), LRURP(), "Replacement policy of pattern history table" ) @@ -1135,6 +1222,7 @@ class XSCompositePrefetcher(QueuedPrefetcher): "Small BOP used in composite prefetcher ") bop_learned = Param.BasePrefetcher(LearnedBOPPrefetcher(is_sub_prefetcher=True), "Learned BOP used in composite prefetcher ") + bop_pf_level = Param.Int(2, "L1 BOP prefetch target level") spp = Param.BasePrefetcher(SignaturePathPrefetcher(is_sub_prefetcher=True), "SPP used in composite prefetcher") ipcp = Param.IPCPrefetcher(IPCPrefetcher(use_rrf = False, is_sub_prefetcher=True), "") diff --git a/src/mem/cache/prefetch/SConscript b/src/mem/cache/prefetch/SConscript index cc4fa206c9..2e074e2fdd 100644 --- a/src/mem/cache/prefetch/SConscript +++ b/src/mem/cache/prefetch/SConscript @@ -89,3 +89,4 @@ Source('composite_with_worker.cc') Source('l2_composite_with_worker.cc') Source('despacito_stream.cc') Source('forwarder.cc') +Source('prefetch_filter.cc') diff --git a/src/mem/cache/prefetch/base.cc b/src/mem/cache/prefetch/base.cc index 77572420aa..0fd233b369 100644 --- a/src/mem/cache/prefetch/base.cc +++ b/src/mem/cache/prefetch/base.cc @@ -46,6 +46,7 @@ #include "mem/cache/prefetch/base.hh" #include +#include #include "base/intmath.hh" #include "debug/HWPrefetch.hh" @@ -111,7 +112,78 @@ Base::PrefetchInfo::PrefetchInfo(PrefetchInfo const &pfi, Addr addr) data(nullptr),data_ptr(nullptr) { } +Base::PrefetchInfo::PrefetchInfo(PrefetchInfo_old const &pfi) + : address(pfi.address), pc(pfi.pc), requestorId(pfi.requestorId), + validPC(pfi.validPC), secure(pfi.secure), size(pfi.size), + write(pfi.write), paddress(pfi.paddress), cacheMiss(pfi.cacheMiss), + data(nullptr),data_ptr(nullptr) +{ +} +Base::PrefetchInfo_old::PrefetchInfo_old(PacketPtr pkt, Addr addr, bool miss) + : address(addr), pc(pkt->req->hasPC() ? 
pkt->req->getPC() : 0), + requestorId(pkt->req->requestorId()), validPC(pkt->req->hasPC()), + secure(pkt->isSecure()), size(pkt->req->getSize()), write(pkt->isWrite()), + paddress(pkt->req->getPaddr()), cacheMiss(miss) +{ + unsigned int req_size = pkt->req->getSize(); + if (!write && miss) { + data = nullptr; + data_ptr = nullptr; + } else if (pkt->isStorePFTrain()) { + data = nullptr; + data_ptr = nullptr; + } else { + data = new uint8_t[req_size]; + Addr offset = pkt->req->getPaddr() - pkt->getAddr(); + std::memcpy(data, &(pkt->getConstPtr()[offset]), req_size); + data_ptr=(uint64_t*)pkt->getPtr(); + } +} + +Base::PrefetchInfo_old::PrefetchInfo_old( + PacketPtr pkt, Addr addr, bool miss, + Request::XsMetadata xsMeta +) : address(addr), pc(pkt->req->hasPC() ? pkt->req->getPC() : 0), + requestorId(pkt->req->requestorId()), validPC(pkt->req->hasPC()), + secure(pkt->isSecure()), size(pkt->req->getSize()), write(pkt->isWrite()), + paddress(pkt->req->getPaddr()), cacheMiss(miss), xsMetadata(xsMeta) +{ + unsigned int req_size = pkt->req->getSize(); + if (!write && miss) { + data = nullptr; + data_ptr = nullptr; + } else if (pkt->isStorePFTrain()) { + data = nullptr; + data_ptr = nullptr; + } else { + data = new uint8_t[req_size]; + Addr offset = pkt->req->getPaddr() - pkt->getAddr(); + std::memcpy(data, &(pkt->getConstPtr()[offset]), req_size); + data_ptr=(uint64_t*)pkt->getPtr(); + } +} +Base::PrefetchInfo_old::PrefetchInfo_old(PrefetchInfo_old const &other) + : address(other.address), pc(other.pc), requestorId(other.requestorId), + validPC(other.validPC), secure(other.secure), size(other.size), + write(other.write), paddress(other.paddress), cacheMiss(other.cacheMiss), + data(nullptr),data_ptr(nullptr) +{ +} +Base::PrefetchInfo_old::PrefetchInfo_old(PrefetchInfo_old const &pfi, Addr addr) + : address(addr), pc(pfi.pc), requestorId(pfi.requestorId), + validPC(pfi.validPC), secure(pfi.secure), size(pfi.size), + write(pfi.write), paddress(pfi.paddress), 
cacheMiss(pfi.cacheMiss), + data(nullptr),data_ptr(nullptr) +{ +} +Base::PrefetchInfo_old::PrefetchInfo_old(PrefetchInfo const &pfi) + : address(pfi.address), pc(pfi.pc), requestorId(pfi.requestorId), + validPC(pfi.validPC), secure(pfi.secure), size(pfi.size), + write(pfi.write), paddress(pfi.paddress), cacheMiss(pfi.cacheMiss), + data(nullptr),data_ptr(nullptr) +{ +} void Base::PrefetchListener::notify(const PacketPtr &pkt) { @@ -126,12 +198,16 @@ Base::PrefetchListener::notify(const PacketPtr &pkt) Base::Base(const BasePrefetcherParams &p) : ClockedObject(p), - listeners(), isSubPrefetcher(p.is_sub_prefetcher), + listeners(), + trainingBufferSize(p.training_buffer_size), + cycleEvent([this]{ processCycle(); }, name()), // TrainFilter cycle event + isSubPrefetcher(p.is_sub_prefetcher), archDBer(p.arch_db), blkSize(p.block_size), lBlkSize(floorLog2(blkSize)), onMiss(p.on_miss), onRead(p.on_read), onWrite(p.on_write), onData(p.on_data), onInst(p.on_inst), requestorId(p.sys->getRequestorId(this)), pageBytes(p.page_bytes), + prefetchTrain(p.prefetch_train), prefetchOnAccess(p.prefetch_on_access), prefetchOnPfHit(p.prefetch_on_pf_hit), useVirtualAddresses(p.use_virtual_addresses), @@ -196,6 +272,10 @@ Base::StatGroup::StatGroup(statistics::Group *parent) "number of prefetches hitting in a MSHR"), ADD_STAT(pfHitInWB, statistics::units::Count::get(), "number of prefetches hit in the Write Buffer"), + ADD_STAT(pfGenerated, statistics::units::Count::get(), + "number of prefetch requests generated by prefetcher"), + ADD_STAT(pfFiltered, statistics::units::Count::get(), + "number of prefetch requests filtered before issuing"), ADD_STAT(pfLate, statistics::units::Count::get(), "number of late prefetches (hitting in cache, MSHR or WB)") { @@ -243,6 +323,10 @@ Base::observeAccess(const PacketPtr &pkt, bool miss) const bool read = pkt->isRead(); bool inv = pkt->isInvalidate(); + // Filter L1 prefetcher requests from training L2 prefetcher + if (pkt->req->isPrefetch() && 
!prefetchTrain) { + return false; + } if (!miss) { if (prefetchOnPfHit) return hasEverBeenPrefetched(pkt->getAddr(), pkt->isSecure()); @@ -385,14 +469,56 @@ Base::probeNotify(const PacketPtr &pkt, bool miss) if (!useVirtualAddresses || pkt->req->hasVaddr()) { // condition1: useVirtualAddresses && pkt->req->hasVaddr() // condition2: !useVirtualAddresses - PrefetchInfo pfi(pkt, pkt->req->hasVaddr() ? pkt->req->getVaddr() : pkt->req->getPaddr(), miss, - Request::XsMetadata(pf_source, pf_depth)); - pfi.setReqAfterSquash(squashMark); - pfi.setEverPrefetched(hasEverBeenPrefetched(pkt->getAddr(), pkt->isSecure())); - pfi.setPfFirstHit(!miss && hasBeenPrefetched(pkt->getAddr(), pkt->isSecure())); - pfi.setPfHit(!miss && hasEverBeenPrefetched(pkt->getAddr(), pkt->isSecure())); + + Addr addr = pkt->req->hasVaddr() ? pkt->req->getVaddr() : pkt->req->getPaddr(); + Request::XsMetadata xsMetadata(pf_source, pf_depth); + + // Query and save all state information needed for training + bool everPrefetched = hasEverBeenPrefetched(pkt->getAddr(), pkt->isSecure()); + bool pfFirstHit = !miss && hasBeenPrefetched(pkt->getAddr(), pkt->isSecure()); + bool pfHit = !miss && everPrefetched; + bool currentSquashMark = squashMark; squashMark = false; - notify(pkt, pfi); + + // TrainFilter: Collect training requests into temporary buffers + if (useTrainingBuffer()) { + // Extract ROB sequence number from packet metadata + InstSeqNum seqNum = getSeqNum(pkt); + Addr blockAddr = getBlockAddr(addr); + bool isLoad = isLoadRequest(pkt); + + // Collect into Load or Store temporary buffer based on request type + if (isLoad) { + currentCycleLoads.emplace_back( + pkt, addr, miss, xsMetadata, + everPrefetched, pfFirstHit, pfHit, currentSquashMark, + seqNum, blockAddr, isLoad + ); + DPRINTF(HWPrefetch, "TrainFilter: Collected Load [seq=%lu, blk=%#x]\n", + seqNum, blockAddr); + } else { + currentCycleStores.emplace_back( + pkt, addr, miss, xsMetadata, + everPrefetched, pfFirstHit, pfHit, currentSquashMark, 
+ seqNum, blockAddr, isLoad + ); + DPRINTF(HWPrefetch, "TrainFilter: Collected Store [seq=%lu, blk=%#x]\n", + seqNum, blockAddr); + } + + if (!cycleEvent.scheduled()) { + schedule(cycleEvent, clockEdge(Cycles(1))); + DPRINTF(HWPrefetch, "TrainFilter: Scheduled processCycle for next cycle\n"); + } + } else { + // When not using buffer, create PrefetchInfo immediately and train + PrefetchInfo pfi(pkt, addr, miss, xsMetadata); + pfi.setReqAfterSquash(currentSquashMark); + pfi.setEverPrefetched(everPrefetched); + pfi.setPfFirstHit(pfFirstHit); + pfi.setPfHit(pfHit); + notify(pkt, pfi); + } } else { DPRINTF(HWPrefetch, "Skip req addr %x, has vaddr: %i\n", pkt->req->hasVaddr() ? pkt->req->getVaddr() : pkt->req->getPaddr(), pkt->req->hasVaddr()); @@ -403,6 +529,175 @@ Base::probeNotify(const PacketPtr &pkt, bool miss) } } +void +Base::processCycle() +{ + DPRINTF(HWPrefetch, "=== TrainFilter Cycle @ Tick %lu ===\n", curTick()); + + // Step 1: Flush previous cycle's collected requests into trainingBuffer + if (!currentCycleLoads.empty() || !currentCycleStores.empty()) { + flushCurrentCycleRequests(); + } + + // Step 2: Train one request from trainingBuffer (if available) + if (!trainingBuffer.empty()) { + processTraining(); + } + + bool hasWork = !currentCycleLoads.empty() || + !currentCycleStores.empty() || + !trainingBuffer.empty(); + + if (hasWork && !cycleEvent.scheduled()) { + schedule(cycleEvent, clockEdge(Cycles(1))); + DPRINTF(HWPrefetch, "TrainFilter: Rescheduled (pending work: %d loads, %d stores, %d in buffer)\n", + currentCycleLoads.size(), currentCycleStores.size(), trainingBuffer.size()); + } else if (!hasWork) { + DPRINTF(HWPrefetch, "TrainFilter: No work remaining, stopping cycle event\n"); + } +} + +void +Base::flushCurrentCycleRequests() +{ + if (currentCycleLoads.empty() && currentCycleStores.empty()) { + return; + } + + DPRINTF(HWPrefetch, "TrainFilter: Flushing %d Loads, %d Stores\n", + currentCycleLoads.size(), currentCycleStores.size()); + + // Step 
1: Sort Load group by ROB order (oldest first) + std::sort(currentCycleLoads.begin(), currentCycleLoads.end(), + [](const TrainingRequest &a, const TrainingRequest &b) { + return a.seqNum < b.seqNum; // Ascending order (oldest first) + }); + + // Step 2: Sort Store group by ROB order (oldest first) + std::sort(currentCycleStores.begin(), currentCycleStores.end(), + [](const TrainingRequest &a, const TrainingRequest &b) { + return a.seqNum < b.seqNum; + }); + + // Step 3: Merge into [Loads..., Stores...] sequence + std::vector sortedRequests; + sortedRequests.reserve(currentCycleLoads.size() + currentCycleStores.size()); + + for (auto &req : currentCycleLoads) { + sortedRequests.push_back(std::move(req)); + } + + for (auto &req : currentCycleStores) { + sortedRequests.push_back(std::move(req)); + } + + DPRINTF(HWPrefetch, "TrainFilter: Reordered sequence: "); + for (const auto &req : sortedRequests) { + DPRINTFR(HWPrefetch, "[%s%lu] ", req.isLoad ? "L" : "S", req.seqNum); + } + DPRINTFR(HWPrefetch, "\n"); + + // Step 4: Filter and insert into trainingBuffer + for (auto &req : sortedRequests) { + Addr blockAddr = req.blockAddr; + + if (trainingBufferBlockAddrs.count(blockAddr) > 0) { + DPRINTF(HWPrefetch, " TrainFilter: Drop [%s%lu, %#x] - in buffer\n", + req.isLoad ? "L" : "S", req.seqNum, blockAddr); + continue; + } + + if (trainingBuffer.size() >= trainingBufferSize) { + DPRINTF(HWPrefetch, " TrainFilter: Drop [%s%lu, %#x] - buffer full\n", + req.isLoad ? "L" : "S", req.seqNum, blockAddr); + continue; + } + + bool isLoad = req.isLoad; + InstSeqNum seqNum = req.seqNum; + + trainingBuffer.push_back(std::move(req)); + trainingBufferBlockAddrs.insert(blockAddr); + + DPRINTF(HWPrefetch, " TrainFilter: Enqueue [%s%lu, %#x] (buffer: %d)\n", + isLoad ? 
"L" : "S", seqNum, blockAddr, + trainingBuffer.size()); + } + + currentCycleLoads.clear(); + currentCycleStores.clear(); +} + +void +Base::processTraining() +{ + if (trainingBuffer.empty()) { + return; + } + + TrainingRequest &req = trainingBuffer.front(); + + DPRINTF(HWPrefetch, ">>> TrainFilter: Training [%s%lu, %#x] (remaining: %d)\n", + req.isLoad ? "L" : "S", req.seqNum, req.blockAddr, + trainingBuffer.size() - 1); + + PacketPtr temp_pkt = new Packet(req.req, req.cmd); + + bool isWrite = temp_pkt->isWrite(); + bool willAccessData = (isWrite || !req.miss) && !temp_pkt->isStorePFTrain(); + + if (req.dataCopy != nullptr) { + temp_pkt->dataDynamic(req.dataCopy); + + const_cast(req).dataCopy = nullptr; + + DPRINTF(HWPrefetch, " TrainFilter: Packet with data (%d bytes)\n", req.dataSize); + } else if (willAccessData) { + DPRINTF(HWPrefetch, " TrainFilter: WARNING - Creating dummy data buffer " + "(original packet had no data, miss=%d, isWrite=%d)\n", + req.miss, isWrite); + + uint8_t *dummyData = new uint8_t[req.dataSize]; + std::memset(dummyData, 0, req.dataSize); + temp_pkt->dataDynamic(dummyData); + } else { + DPRINTF(HWPrefetch, " TrainFilter: Packet without data (miss=%d, isWrite=%d)\n", + req.miss, isWrite); + } + + PrefetchInfo pfi(temp_pkt, req.addr, req.miss, req.xsMetadata); + pfi.setReqAfterSquash(req.squashMark); + pfi.setEverPrefetched(req.everPrefetched); + pfi.setPfFirstHit(req.pfFirstHit); + pfi.setPfHit(req.pfHit); + notify(temp_pkt, pfi); + + delete temp_pkt; + + trainingBufferBlockAddrs.erase(req.blockAddr); + + trainingBuffer.pop_front(); +} + +InstSeqNum +Base::getSeqNum(const PacketPtr &pkt) const +{ + // Try to get seqNum from XsMeta data + if (pkt->req->getXsMetadata().validXsMetadata && + pkt->req->getXsMetadata().instXsMetadata) { + return pkt->req->getXsMetadata().instXsMetadata->seqNum; + } + + panic("cannot get valid seqNum\n"); + +} + +bool +Base::isLoadRequest(const PacketPtr &pkt) const +{ + return pkt->isRead() && !pkt->isWrite(); +} 
+ void Base::coreDirectAddrNotify(const PacketPtr& pkt) { diff --git a/src/mem/cache/prefetch/base.hh b/src/mem/cache/prefetch/base.hh index 0e9a2bf4e2..2ba250fa48 100644 --- a/src/mem/cache/prefetch/base.hh +++ b/src/mem/cache/prefetch/base.hh @@ -47,6 +47,9 @@ #define __MEM_CACHE_PREFETCH_BASE_HH__ #include +#include +#include +#include #include "arch/generic/tlb.hh" #include "base/compiler.hh" @@ -77,6 +80,11 @@ struct CustomPfInfo class Base : public ClockedObject { + public: + struct PFtriggerInfo; + class PrefetchInfo; + class PrefetchInfo_old; + private: friend class PrefetcherForwarder; class PrefetchListener : public ProbeListenerArgBase { @@ -99,13 +107,56 @@ class Base : public ClockedObject std::vector listeners; public: - + struct PFtriggerInfo{ + PacketPtr pkt; + std::unique_ptr pfi_old; + PrefetchSourceType pfSourceType; + PFtriggerInfo() : pkt(nullptr), pfi_old(nullptr), pfSourceType(PrefetchSourceType::PF_NONE) {} + PFtriggerInfo(PacketPtr p, const PrefetchInfo &a) + : pkt(p ? new Packet(p, false, false) : nullptr), + pfi_old(std::make_unique(a)), pfSourceType(PrefetchSourceType::PF_NONE) {} + PFtriggerInfo(const PFtriggerInfo &other) + : pkt(other.pkt ? new Packet(other.pkt, false, false) : nullptr), + pfi_old(other.pfi_old ? std::make_unique(*(other.pfi_old)) : nullptr), + pfSourceType(other.pfSourceType) {} + PFtriggerInfo& operator=(const PFtriggerInfo &other) + { + if (this != &other) { + delete pkt; + pkt = other.pkt ? 
new Packet(other.pkt, false, false) : nullptr; + pfi_old = std::make_unique(*(other.pfi_old)); + pfSourceType = other.pfSourceType; + } + return *this; + } + // PFtriggerInfo(PFtriggerInfo &&other) noexcept + // : pkt(other.pkt), pfi_old(std::move(other.pfi_old)) + // { + // other.pkt = nullptr; + // } + // PFtriggerInfo& operator=(PFtriggerInfo &&other) noexcept + // { + // if (this != &other) { + // delete pkt; + // pkt = other.pkt; + // pfi_old = std::move(other.pfi_old); + // other.pkt = nullptr; + // } + // return *this; + // } + ~PFtriggerInfo() + { + delete pkt; + pfi_old.reset(); + } + }; /** * Class containing the information needed by the prefetch to train and * generate new prefetch requests. */ class PrefetchInfo { + friend class PrefetchInfo_old; /** The address used to train and generate prefetches */ Addr address; /** The program counter that generated this address. */ @@ -327,6 +378,7 @@ class Base : public ClockedObject * @param addr the address value of the new object */ PrefetchInfo(PrefetchInfo const &pfi, Addr addr); + PrefetchInfo(PrefetchInfo_old const &pfi); ~PrefetchInfo() { @@ -334,9 +386,405 @@ class Base : public ClockedObject } bool lastPfLate{false}; + mutable PFtriggerInfo trigger_info{}; + void setTriggerInfo(const PacketPtr &pkt) const { + trigger_info = PFtriggerInfo(pkt, *this); + } + void setTriggerInfo_PFsrc(const PrefetchSourceType pfSource) const { + trigger_info.pfSourceType = pfSource; + } }; + /** + * Class containing the information needed by the prefetch to train and + * generate new prefetch requests. this is only used by PFtriggerInfo + */ + class PrefetchInfo_old + { + friend class PrefetchInfo; + /** The address used to train and generate prefetches */ + Addr address; + /** The program counter that generated this address. */ + Addr pc; + /** The requestor ID that generated this address. */ + RequestorID requestorId; + /** Validity bit for the PC of this address. 
*/ + bool validPC; + /** Whether this address targets the secure memory space. */ + bool secure; + /** Size in bytes of the request triggering this event */ + unsigned int size; + /** Whether this event comes from a write request */ + bool write; + /** Physical address, needed because address can be virtual */ + Addr paddress; + /** Whether this event comes from a cache miss */ + bool cacheMiss; + /** Pointer to the associated request data */ + uint8_t *data; + /** XiangShan metadata of the block*/ + Request::XsMetadata xsMetadata; + + bool reqAfterSquash{false}; + + bool everPrefetched{false}; + + bool pfFirstHit{false}; + + bool pfHit{false}; + + bool storePFTrain{ false }; + + uint64_t *data_ptr; + + public: + uint64_t * getDataPtr()const{ + return data_ptr; + } + /** + * Obtains the address value of this Prefetcher address. + * @return the addres value. + */ + Addr getAddr() const + { + return address; + } + + /** + * Returns true if the address targets the secure memory space. + * @return true if the address targets the secure memory space. + */ + bool isSecure() const + { + return secure; + } + /** + * Returns the program counter that generated this request. 
+ * @return the pc value + */ + Addr getPC() const + { + assert(hasPC()); + return pc; + } + + /** + * Returns true if the associated program counter is valid + * @return true if the program counter has a valid value + */ + bool hasPC() const + { + return validPC; + } + + /** + * Gets the requestor ID that generated this address + * @return the requestor ID that generated this address + */ + RequestorID getRequestorId() const + { + return requestorId; + } + + /** + * Gets the size of the request triggering this event + * @return the size in bytes of the request triggering this event + */ + unsigned int getSize() const + { + return size; + } + + /** + * Checks if the request that caused this prefetch event was a write + * request come from committed store inst + * @return true if the request causing this event is a write request + */ + bool isWrite() const + { + return write; + } + + // is come from store prefetch train trigger + bool isStore() const + { + return storePFTrain; + } + + /** + * Gets the physical address of the request + * @return physical address of the request + */ + Addr getPaddr() const + { + return paddress; + } + + /** + * Check if this event comes from a cache miss + * @result true if this event comes from a cache miss + */ + bool isCacheMiss() const + { + return cacheMiss; + } + + /** + * Gets the associated data of the request triggering the event + * @param Byte ordering of the stored data + * @return the data + */ + template + inline T + get(ByteOrder endian) const + { + if (data == nullptr) { + panic("PrefetchInfo::get called with a request with no data."); + } + switch (endian) { + case ByteOrder::big: + return betoh(*(T*)data); + + case ByteOrder::little: + return letoh(*(T*)data); + + default: + panic("Illegal byte order in PrefetchInfo::get()\n"); + }; + } + + /** + * Check for equality + * @param pfi PrefetchInfo to compare against + * @return True if this object and the provided one are equal + */ + bool sameAddr(PrefetchInfo_old 
const &pfi) const + { + return this->getAddr() == pfi.getAddr() && + this->isSecure() == pfi.isSecure(); + } + + bool sameAddr(Addr addr, bool isSecure) const + { + return this->getAddr() == addr && + this->isSecure() == isSecure; + } + + Request::XsMetadata getXsMetadata() const + { + return xsMetadata; + } + + void setXsMetadata(const Request::XsMetadata &xs_metadata) + { + this->xsMetadata = xs_metadata; + } + + bool isReqAfterSquash() const + { + return reqAfterSquash; + } + + void setReqAfterSquash(bool req_after_squash) + { + reqAfterSquash = req_after_squash; + } + + bool isEverPrefetched() const { return everPrefetched; } + + void setEverPrefetched(bool prefetched) { everPrefetched = prefetched; } + + bool isPfHit() const { return pfHit; } + + void setPfHit(bool hit) { pfHit = hit; } + + bool isPfFirstHit() const { return pfFirstHit; } + + void setPfFirstHit(bool hit) { pfFirstHit = hit; } + + void setStorePftrain(bool s) { storePFTrain = s; } + + /** + * Constructs a PrefetchInfo using a PacketPtr. + * @param pkt PacketPtr used to generate the PrefetchInfo + * @param addr the address value of the new object, this address is + * used to train the prefetcher + * @param miss whether this event comes from a cache miss + */ + PrefetchInfo_old(PacketPtr pkt, Addr addr, bool miss); + + PrefetchInfo_old(PacketPtr pkt, Addr addr, bool miss, Request::XsMetadata xsMeta); + + /** + * Constructs a PrefetchInfo using a new address value and + * another PrefetchInfo as a reference. 
+ * @param pfi PrefetchInfo used to generate this new object + * @param addr the address value of the new object + */ + PrefetchInfo_old(PrefetchInfo_old const &pfi, Addr addr); + + PrefetchInfo_old(PrefetchInfo_old const &other); + + PrefetchInfo_old(PrefetchInfo const &pfi); + + ~PrefetchInfo_old() + { + delete[] data; + } + + bool lastPfLate{false}; + }; protected: + /** + * TrainFilter: ROB-order training request filtering and reordering + * + * The TrainFilter collects training requests within a cycle, reorders them + * by ROB sequence number (Load first, then Store), filters duplicates, and + * feeds them into a FIFO training buffer at a rate of one per cycle. + */ + struct TrainingRequest + { + RequestPtr req; + MemCmd cmd; + PacketDataPtr dataCopy; // Deep copy of packet data + unsigned dataSize; // Size of data copied + + Addr addr; // Training address + bool miss; + Request::XsMetadata xsMetadata; + bool everPrefetched; + bool pfFirstHit; + bool pfHit; + bool squashMark; // Request after squash + + // TrainFilter fields for ROB-order training + InstSeqNum seqNum; // ROB sequence number for ordering + Addr blockAddr; // Cache block address for filtering + bool isLoad; + + // Constructor: Extract copies from PacketPtr + TrainingRequest(PacketPtr pkt, Addr _addr, bool _miss, + const Request::XsMetadata &_xsMetadata, + bool _everPrefetched, bool _pfFirstHit, + bool _pfHit, bool _squashMark, + InstSeqNum _seqNum, Addr _blockAddr, bool _isLoad) + : req(pkt->req), + cmd(pkt->cmd), + dataCopy(nullptr), + dataSize(pkt->getSize()), + addr(_addr), + miss(_miss), + xsMetadata(_xsMetadata), + everPrefetched(_everPrefetched), + pfFirstHit(_pfFirstHit), + pfHit(_pfHit), + squashMark(_squashMark), + seqNum(_seqNum), + blockAddr(_blockAddr), + isLoad(_isLoad) + { + // Deep copy packet data if present + if (pkt->flags.isSet(Packet::STATIC_DATA | Packet::DYNAMIC_DATA)) { + dataCopy = new uint8_t[dataSize]; + std::memcpy(dataCopy, pkt->getConstPtr(), dataSize); + } + } + + 
// Destructor: Free our owned data copy + ~TrainingRequest() { + if (dataCopy) { + delete[] dataCopy; + dataCopy = nullptr; + } + } + + // Move constructor: Transfer ownership of dataCopy + TrainingRequest(TrainingRequest&& other) noexcept + : req(std::move(other.req)), + cmd(other.cmd), + dataCopy(other.dataCopy), + dataSize(other.dataSize), + addr(other.addr), + miss(other.miss), + xsMetadata(other.xsMetadata), + everPrefetched(other.everPrefetched), + pfFirstHit(other.pfFirstHit), + pfHit(other.pfHit), + squashMark(other.squashMark), + seqNum(other.seqNum), + blockAddr(other.blockAddr), + isLoad(other.isLoad) + { + other.dataCopy = nullptr; // Transfer ownership + } + + // Move assignment: Transfer ownership of dataCopy + TrainingRequest& operator=(TrainingRequest&& other) noexcept { + if (this != &other) { + // Free our old data + if (dataCopy) delete[] dataCopy; + + // Transfer everything from other + req = std::move(other.req); + cmd = other.cmd; + dataCopy = other.dataCopy; + dataSize = other.dataSize; + addr = other.addr; + miss = other.miss; + xsMetadata = other.xsMetadata; + everPrefetched = other.everPrefetched; + pfFirstHit = other.pfFirstHit; + pfHit = other.pfHit; + squashMark = other.squashMark; + seqNum = other.seqNum; + blockAddr = other.blockAddr; + isLoad = other.isLoad; + + other.dataCopy = nullptr; + } + return *this; + } + + // Disable copy constructor/assignment + TrainingRequest(const TrainingRequest&) = delete; + TrainingRequest& operator=(const TrainingRequest&) = delete; + }; + + std::vector currentCycleLoads; + std::vector currentCycleStores; + std::deque trainingBuffer; + + std::unordered_set trainingBufferBlockAddrs; + + /** Maximum size of the training buffer */ + const unsigned trainingBufferSize; + + + /** + * Periodic event that fires every cycle + * Handles: 1) flush previous cycle requests, 2) train one request + * This ensures training progresses even when there are no cache accesses + */ + EventFunctionWrapper cycleEvent; + + 
void processCycle(); + + void flushCurrentCycleRequests(); + + /** + * Train one request from the front of trainingBuffer + * Dequeues one request and calls notify() to train the prefetcher + */ + void processTraining(); + + /** Whether to use training buffer (can be overridden by subclasses) */ + virtual bool useTrainingBuffer() const { return false; } + + Addr getBlockAddr(Addr addr) const { return blockAddress(addr); } + + InstSeqNum getSeqNum(const PacketPtr &pkt) const; + + bool isLoadRequest(const PacketPtr &pkt) const; bool isSubPrefetcher; @@ -379,6 +827,9 @@ class Base : public ClockedObject const Addr pageBytes; + /** Allow upstream PF req train low level Prefetcher */ + const bool prefetchTrain; + /** Prefetch on every access, not just misses */ const bool prefetchOnAccess; @@ -455,6 +906,12 @@ class Base : public ClockedObject * in the Write Buffer (WB). */ statistics::Scalar pfHitInWB; + /** The number of prefetch requests generated by prefetcher. */ + statistics::Scalar pfGenerated; + + /** The number of prefetch requests filtered before issuing. */ + statistics::Scalar pfFiltered; + /** The number of times a HW-prefetch is late * (hit in cache, MSHR, WB). 
*/ statistics::Formula pfLate; diff --git a/src/mem/cache/prefetch/berti.cc b/src/mem/cache/prefetch/berti.cc index 4e2a8492b2..e555eb7144 100644 --- a/src/mem/cache/prefetch/berti.cc +++ b/src/mem/cache/prefetch/berti.cc @@ -275,6 +275,9 @@ BertiPrefetcher::sendPFWithFilter(const PrefetchInfo &pfi, Addr addr, std::vecto if (using_best_delta_and_confident) { lastUsedBestDelta = blockIndex(addr) - blockIndex(pfi.getAddr()); } + // buffered prefetch + InsertPFRequestToBuffer(AddrPriority(addr, prio, src, pfi.trigger_info)); + if (filter->contains(addr)) { DPRINTF(BertiPrefetcher, "Skip recently prefetched: %lx\n", addr); return false; diff --git a/src/mem/cache/prefetch/bop.cc b/src/mem/cache/prefetch/bop.cc index 10ad1bb11e..7220d74e53 100644 --- a/src/mem/cache/prefetch/bop.cc +++ b/src/mem/cache/prefetch/bop.cc @@ -127,9 +127,30 @@ BOP::delayQueueEventWrapper() unsigned int BOP::hash(Addr addr, unsigned int way) const { - Addr hash1 = addr; - Addr hash2 = hash1 >> floorLog2(rrEntries); - return (hash1 ^ hash2) & (Addr)(rrEntries - 1); + // NOTE: This unit-test BOP is used to replay XiangShan-generated traces. + // Align RR indexing with XiangShan Chisel (BestOffsetPrefetch.scala): + // lineAddr = addr >> offsetBits + // hash1 = lineAddr[rrIdxBits-1:0] + // hash2 = lineAddr[2*rrIdxBits-1:rrIdxBits] + // idx = hash1 ^ hash2 + // + // The original gem5 BOP implementation used two banks (Left/Right) with + // different hashing. XiangShan uses a single direct-mapped RR, so 'way' + // is ignored here. 
+ // + // Original gem5 BOP (indexed using the *tag* value, not full addr): + // Addr hash1 = tag >> way; + // Addr hash2 = hash1 >> floorLog2(rrEntries); + // idx = (hash1 ^ hash2) & (rrEntries - 1); + (void)way; + + const unsigned rrIdxBits = floorLog2(rrEntries); + const unsigned offsetBits = floorLog2(blkSize); + const Addr line_addr = addr >> offsetBits; + const Addr mask = static_cast(rrEntries - 1); + const Addr hash1 = line_addr & mask; + const Addr hash2 = (line_addr >> rrIdxBits) & mask; + return static_cast((hash1 ^ hash2) & mask); } void @@ -143,10 +164,10 @@ BOP::insertIntoRR(RREntryDebug rr_entry, unsigned int way) { switch (way) { case RRWay::Left: - rrLeft[hash(rr_entry.hashAddr, RRWay::Left)] = rr_entry; + rrLeft[hash(rr_entry.fullAddr, RRWay::Left)] = rr_entry; break; case RRWay::Right: - rrRight[hash(rr_entry.hashAddr, RRWay::Right)] = rr_entry; + rrRight[hash(rr_entry.fullAddr, RRWay::Right)] = rr_entry; break; } } @@ -180,17 +201,30 @@ BOP::resetScores() inline Addr BOP::tag(Addr addr) const { - return (addr >> lBlkSize) & tagMask; + // Align tag extraction with XiangShan Chisel (BestOffsetPrefetch.scala): + // tag = lineAddr[rrIdxBits+rrTagBits-1:rrIdxBits] + // where lineAddr = addr >> offsetBits. + // + // Original gem5 BOP (commented) used: + // (addr >> offsetBits) & tagMask + // which kept the lowest tagBits of the line address. 
+ const unsigned rrIdxBits = floorLog2(rrEntries); + const unsigned offsetBits = floorLog2(blkSize); + const Addr line_addr = addr >> offsetBits; + return (line_addr >> rrIdxBits) & tagMask; } std::pair -BOP::testRR(Addr tag) const +BOP::testRR(Addr addr) const { - if (rrLeft[hash(tag, RRWay::Left)].hashAddr == tag) { - return std::make_pair(true, rrLeft[hash(tag, RRWay::Left)]); + const Addr t = tag(addr); + const unsigned idx_l = hash(addr, RRWay::Left); + if (rrLeft[idx_l].hashAddr == t) { + return std::make_pair(true, rrLeft[idx_l]); } - if (rrRight[hash(tag, RRWay::Right)].hashAddr == tag) { - return std::make_pair(true, rrRight[hash(tag, RRWay::Right)]); + const unsigned idx_r = hash(addr, RRWay::Right); + if (rrRight[idx_r].hashAddr == t) { + return std::make_pair(true, rrRight[idx_r]); } return std::make_pair(false, RREntryDebug()); @@ -396,11 +430,12 @@ BOP::bestOffsetLearning(Addr x, bool late, const PrefetchInfo &pfi) resetScores(); //issuePrefetchRequests = true; return true; - } else if ((round >= roundMax/2) && (bestOffset != phaseBestOffset) && (bestScore <= badScore)) { - DPRINTF(BOPPrefetcher, "last round offset has not enough confidence, early stop\n"); - DPRINTF(BOPPrefetcher, "score %u < badScore %u\n", bestScore, badScore); - issuePrefetchRequests = false; - } + } // here temporarily disable early stop, to align with RTL + // else if ((round >= roundMax/2) && (bestOffset != phaseBestOffset) && (bestScore <= badScore)) { + // DPRINTF(BOPPrefetcher, "last round offset has not enough confidence, early stop\n"); + // DPRINTF(BOPPrefetcher, "score %u < badScore %u\n", bestScore, badScore); + // issuePrefetchRequests = false; + // } } DPRINTF(BOPPrefetcher, "Reach %s end, iter offset: %d\n", __FUNCTION__, offsetsListIterator->calcOffset()); return false; @@ -453,14 +488,22 @@ bool BOP::sendPFWithFilter(const PrefetchInfo &pfi, Addr addr, std::vector &addresses, int prio, PrefetchSourceType src) { + // Count generated prefetch + 
prefetchStats.pfGenerated++; + if (!samePage(pfi.getAddr(), addr) && !crossPage) { + // Count filtered prefetch (cross-page) + prefetchStats.pfFiltered++; return false; } if (archDBer && cache->level() == 1) { archDBer->l1PFTraceWrite(curTick(), pfi.getPC(), pfi.getAddr(), addr, src); } + InsertPFRequestToBuffer(AddrPriority(addr, prio, src, pfi.trigger_info)); if (filter->contains(addr)) { DPRINTF(BOPPrefetcher, "Skip recently prefetched: %lx\n", addr); + // Count filtered prefetch + prefetchStats.pfFiltered++; return false; } else { DPRINTF(BOPPrefetcher, "Send pf: %lx\n", addr); diff --git a/src/mem/cache/prefetch/cdp.cc b/src/mem/cache/prefetch/cdp.cc index ffde6ed7f2..525a2fcde6 100644 --- a/src/mem/cache/prefetch/cdp.cc +++ b/src/mem/cache/prefetch/cdp.cc @@ -169,13 +169,13 @@ CDP::calculatePrefetch(const PrefetchInfo &pfi, std::vector &addre vpn2 = BITS(pt_addr, 38, 30); vpn1 = BITS(pt_addr, 29, 21); vpnTable.update(vpn2, vpn1, enable_thro, isLowConfidence()); - sendPFWithFilter(blockAddress(pt_addr), addresses, 30, PrefetchSourceType::CDP, 1); + sendPFWithFilter(pfi, blockAddress(pt_addr), addresses, 30, PrefetchSourceType::CDP, 1); for (int i = 1; i < degree; i++) { if (getCdpTrueAccuracy() > 0.05) { Addr next_pf_addr = blockAddress(pt_addr) + (i * 0x40); vpnTable.update(BITS(next_pf_addr, 38, 30), BITS(next_pf_addr, 29, 21), enable_thro, isLowConfidence()); - sendPFWithFilter(next_pf_addr, addresses, 1, PrefetchSourceType::CDP, 1); + sendPFWithFilter(pfi, next_pf_addr, addresses, 1, PrefetchSourceType::CDP, 1); } } cdpStats.triggeredInCalcPf++; @@ -302,14 +302,14 @@ CDP::notifyWithData(const PacketPtr &pkt, bool is_l1_use, std::vector 0.05) { Addr next_pf_addr = blockAddress(test_addr2) + (i * 0x40); vpnTable.update(BITS(next_pf_addr, 38, 30), BITS(next_pf_addr, 29, 21), enable_thro, isLowConfidence()); - sendPFWithFilter(next_pf_addr, addresses, 1, PrefetchSourceType::CDP, + sendPFWithFilter(pkt, next_pf_addr, addresses, 1, PrefetchSourceType::CDP, 
next_depth); } } @@ -344,9 +344,32 @@ CDP::pfHitNotify(float accuracy, PrefetchSourceType pf_source, const PacketPtr & } bool -CDP::sendPFWithFilter(Addr addr, std::vector &addresses, int prio, PrefetchSourceType pfSource, - int pf_depth) +CDP::sendPFWithFilter(const PacketPtr &pkt, Addr addr, std::vector &addresses, + int prio, PrefetchSourceType pfSource, int pf_depth) { + //fake a PrefetchInfo, thus this reill pf will use Queued::PFSendEventWrapper to send out pf req + PrefetchInfo pfi(pkt, pkt->req->getVaddr(), false); + pfi.setTriggerInfo(pkt); + + InsertPFRequestToBuffer(AddrPriority(addr, prio, pfSource, pfi.trigger_info)); + if (pfLRUFilter->contains((addr))) { + return false; + } else { + pfLRUFilter->insert((addr), 0); + AddrPriority addr_prio = AddrPriority(addr, prio, pfSource); + addr_prio.depth = pf_depth; + addresses.push_back(addr_prio); + cdpStats.passedFilter++; + return true; + } + return false; +} + +bool +CDP::sendPFWithFilter(const PrefetchInfo &pfi, Addr addr, std::vector &addresses, + int prio, PrefetchSourceType pfSource, int pf_depth) +{ + InsertPFRequestToBuffer(AddrPriority(addr, prio, pfSource, pfi.trigger_info)); if (pfLRUFilter->contains((addr))) { return false; } else { diff --git a/src/mem/cache/prefetch/cdp.hh b/src/mem/cache/prefetch/cdp.hh index 58cf707897..6f309ad989 100644 --- a/src/mem/cache/prefetch/cdp.hh +++ b/src/mem/cache/prefetch/cdp.hh @@ -414,8 +414,10 @@ class CDP : public Queued { rivalCoverage = info.coverage; } - bool sendPFWithFilter(Addr addr, std::vector &addresses, int prio, PrefetchSourceType pfSource, - int pf_depth); + bool sendPFWithFilter(const PacketPtr &pkt, Addr addr, std::vector &addresses, + int prio, PrefetchSourceType pfSource, int pf_depth); + bool sendPFWithFilter(const PrefetchInfo &pfi, Addr addr, std::vector &addresses, + int prio, PrefetchSourceType pfSource, int pf_depth); CDP(const CDPParams &p); diff --git a/src/mem/cache/prefetch/cmc.cc b/src/mem/cache/prefetch/cmc.cc index 
331501b42c..162293842d 100644 --- a/src/mem/cache/prefetch/cmc.cc +++ b/src/mem/cache/prefetch/cmc.cc @@ -1,5 +1,7 @@ #include "mem/cache/prefetch/cmc.hh" +#include + #include "base/output.hh" #include "debug/CMCPrefetcher.hh" #include "mem/cache/prefetch/associative_set_impl.hh" @@ -70,6 +72,8 @@ CMCPrefetcher::CMCPrefetcher(const CMCPrefetcherParams &p) db.save_db(simout.resolve("cmc.db").c_str()); } }); + sendingEntry.invalidate(); + sendIDX_PTR = 0; } void @@ -132,7 +136,13 @@ CMCPrefetcher::doPrefetch(const PrefetchInfo &pfi, std::vector &ad match_entry->refcnt++; int priority = recorder->nr_entry; uint32_t id = match_entry->id; - + //create a copy , insert to tpdataqueue + StorageEntry entry_copy = StorageEntry(*match_entry); + entry_copy.trigger = std::make_unique(pfi.trigger_info); + if( tpDataQueue.size() >= maxTpDataQueueSize){ + tpDataQueue.pop_front(); + } + tpDataQueue.push_back(entry_copy); int num_send = 0; for (auto addr: match_entry->addresses) { // addresses.push_back(AddrPriority(addr, mixedNum, PrefetchSourceType::CMC)); @@ -275,8 +285,13 @@ bool CMCPrefetcher::sendPFWithFilter(const PrefetchInfo &pfi, Addr addr, std::vector &addresses, int prio, PrefetchSourceType src) { + // Count generated prefetch + prefetchStats.pfGenerated++; + if (filter->contains(addr)) { DPRINTF(CMCPrefetcher, "Skip recently prefetched: %lx\n", addr); + // Count filtered prefetch + prefetchStats.pfFiltered++; return false; } else { DPRINTF(CMCPrefetcher, "CMC: send pf: %lx\n", addr); @@ -302,6 +317,66 @@ CMCPrefetcher::StorageEntry::invalidate() { } TaggedEntry::invalidate(); } +void +CMCPrefetcher::InsertPFRequestToBuffer(const AddrPriority &addr_prio) { + panic("CMCPrefetcher: InsertPFRequestToBuffer not implemented"); +} +bool +CMCPrefetcher::hasPFRequestsInBuffer() { + return !tpDataQueue.empty() || sendingEntry.isValid(); +} +bool +CMCPrefetcher::GetPFRequestsFromBuffer(std::vector &addresses) { + //if sendingEntry is valid, send next addr + 
if(sendingEntry.isValid()){ + if(sendIDX_PTR < sendingEntry.addresses.size()){ + Addr addr = sendingEntry.addresses[sendIDX_PTR]; + sendIDX_PTR++; + if (sendingEntry.trigger) { + addresses.push_back(AddrPriority(addr, + recorder->nr_entry - sendIDX_PTR + 1, + PrefetchSourceType::CMC, + *(sendingEntry.trigger))); + } else { + addresses.push_back(AddrPriority(addr, + recorder->nr_entry - sendIDX_PTR + 1, + PrefetchSourceType::CMC)); + } + return true; + }else{ + //finished sending this entry + sendingEntry = StorageEntry(); + sendingEntry.invalidate(); + sendIDX_PTR = 0; + } + } + //load next entry from tpDataQueue + if(!tpDataQueue.empty()){ + //copy front entry to sendingEntry + sendingEntry = StorageEntry(tpDataQueue.front()); + tpDataQueue.pop_front(); + sendIDX_PTR = 0; + if(sendIDX_PTR < sendingEntry.addresses.size()){ + Addr addr = sendingEntry.addresses[sendIDX_PTR]; + sendIDX_PTR++; + if (sendingEntry.trigger) { + addresses.push_back(AddrPriority(addr, + recorder->nr_entry - sendIDX_PTR + 1, + PrefetchSourceType::CMC, + *(sendingEntry.trigger))); + } else { + addresses.push_back(AddrPriority(addr, + recorder->nr_entry - sendIDX_PTR + 1, + PrefetchSourceType::CMC)); + } + return true; + }else{ + //should not happen + panic("CMCPrefetcher: empty addresses in sendingEntry"); + } + } + return false; +} } // prefetch } // gem5 diff --git a/src/mem/cache/prefetch/cmc.hh b/src/mem/cache/prefetch/cmc.hh index 03d41530bb..d13294cd3f 100644 --- a/src/mem/cache/prefetch/cmc.hh +++ b/src/mem/cache/prefetch/cmc.hh @@ -3,6 +3,9 @@ #include #include +#include +#include +#include #include "base/types.hh" #include "cpu/pred/general_arch_db.hh" @@ -25,6 +28,7 @@ namespace prefetch class CMCPrefetcher : public Queued { public: + using TriggerInfo = PFTriggerInfo; class StorageEntry; class RecordEntry { @@ -48,7 +52,7 @@ class CMCPrefetcher : public Queued bool train_entry(Addr, bool, bool*); void reset(); - const int nr_entry = 16; + const int nr_entry = 12; private: }; @@ 
-59,6 +63,41 @@ class CMCPrefetcher : public Queued int refcnt; uint64_t id; void invalidate() override; + std::unique_ptr trigger; + StorageEntry() : addresses(), refcnt(0), id(0), trigger(nullptr) {} + + // copy constructor + StorageEntry(const StorageEntry &other) + : TaggedEntry(other), + addresses(other.addresses), + refcnt(other.refcnt), + id(other.id) + { + if (other.trigger) { + trigger = std::make_unique(*(other.trigger)); + } + } + + // copy assignment + StorageEntry& operator=(const StorageEntry &other) + { + if (this != &other) { + TaggedEntry::operator=(other); + addresses = other.addresses; + refcnt = other.refcnt; + id = other.id; + if (other.trigger) { + trigger = std::make_unique(*(other.trigger)); + } else { + trigger.reset(); + } + } + return *this; + } + + StorageEntry(StorageEntry &&) noexcept = default; + StorageEntry& operator=(StorageEntry &&) noexcept = default; + ~StorageEntry() = default; }; private: Recorder *recorder; @@ -95,6 +134,15 @@ class CMCPrefetcher : public Queued static const int STACK_SIZE = 4; boost::circular_buffer trigger; // RecordEntry trigger_stack[STACK_SIZE]; + protected: + std::list tpDataQueue; + const int maxTpDataQueueSize = 8; + StorageEntry sendingEntry; + int sendIDX_PTR = 0;// point to the next idx of sendingEntry + void InsertPFRequestToBuffer(const AddrPriority &addr_prio) override; + public: + bool GetPFRequestsFromBuffer(std::vector &addresses) override; + bool hasPFRequestsInBuffer() override; }; struct TriggerTrace : public Record diff --git a/src/mem/cache/prefetch/despacito_stream.cc b/src/mem/cache/prefetch/despacito_stream.cc index 12e5f292bf..46f5006a45 100644 --- a/src/mem/cache/prefetch/despacito_stream.cc +++ b/src/mem/cache/prefetch/despacito_stream.cc @@ -104,6 +104,11 @@ DespacitoStreamPrefetcher::updatePatternTable(SamplerEntry *sampler_entry) void DespacitoStreamPrefetcher::calculatePrefetch(const PrefetchInfo &pfi, std::vector &addresses, bool late) { + if (archDBer){ + 
archDBer->despacitoTraceWrite(curTick(), pfi.getAddr(), pfi.getPaddr(), pfi.hasPC() ? pfi.getPC() : 0, + pfi.hasPC(), pfi.isCacheMiss(), true); + } + if (!pfi.hasPC()) { return; } @@ -146,6 +151,11 @@ DespacitoStreamPrefetcher::sendPFWithFilter(const PrefetchInfo &pfi, Addr addr, if (archDBer && cache->level() == 1) { archDBer->l1PFTraceWrite(curTick(), pfi.getPC(), pfi.getAddr(), addr, src); } + if (archDBer){ + archDBer->despacitoTraceWrite(curTick(), addr, 0, pfi.hasPC() ? pfi.getPC() : 0, + pfi.hasPC(), pfi.isCacheMiss(), false); + } + InsertPFRequestToBuffer(AddrPriority(addr, prio, src, pfi.trigger_info)); if (filter->contains(addr)) { DPRINTF(DespacitoStreamPrefetcher, "Skip recently prefetched: %lx\n", addr); return false; diff --git a/src/mem/cache/prefetch/ipcp.cc b/src/mem/cache/prefetch/ipcp.cc index 0901da58c6..edc7e2ba57 100644 --- a/src/mem/cache/prefetch/ipcp.cc +++ b/src/mem/cache/prefetch/ipcp.cc @@ -57,6 +57,7 @@ IPCP::sendPFWithFilter(const PrefetchInfo &pfi, Addr addr, std::vectorcontains(addr)) { DPRINTF(IPCP, "IPCP PF filtered\n"); ipcpStats.pf_filtered++; diff --git a/src/mem/cache/prefetch/l2_composite_with_worker.cc b/src/mem/cache/prefetch/l2_composite_with_worker.cc index b7ede79a17..b0b70f8d8a 100644 --- a/src/mem/cache/prefetch/l2_composite_with_worker.cc +++ b/src/mem/cache/prefetch/l2_composite_with_worker.cc @@ -144,6 +144,46 @@ L2CompositeWithWorkerPrefetcher::notifyFill(const PacketPtr &pkt) } addressGenBuffer.clear(); } - +bool +L2CompositeWithWorkerPrefetcher::GetPFRequestsFromBuffer(std::vector &addresses) +{ + //here we decide which to send for this cycle + //L1 Streamstride>berti>SMS>CMC + //L2 Streamstride>SMS>vBOP>pbop>TP + if(pfq.size() == queueSize) { + return false; + } + bool L2PFsent = false; + L2PFsent = ticksToCycles(latestTransferTick) == ticksToCycles(curTick()); + if (!L2PFsent && largeBOP->hasPFRequestsInBuffer()){ + L2PFsent = largeBOP->GetPFRequestsFromBuffer(addresses); + } + if (!L2PFsent && 
smallBOP->hasPFRequestsInBuffer()){ + L2PFsent = smallBOP->GetPFRequestsFromBuffer(addresses); + } + if (!L2PFsent && despacitoStream->hasPFRequestsInBuffer()){ + L2PFsent = despacitoStream->GetPFRequestsFromBuffer(addresses); + } + if (!L2PFsent && cdp->hasPFRequestsInBuffer()){ + L2PFsent = cdp->GetPFRequestsFromBuffer(addresses); + } + if (!L2PFsent && cmc->hasPFRequestsInBuffer()){ + L2PFsent = cmc->GetPFRequestsFromBuffer(addresses); + } + // For now we dont have L3PF + // bool L3PFsent = false; + // L3PFsent = stridestream_pfFilter_l2l3.GetPFAddrL3(addresses); + // if (!L3PFsent){ + // L3PFsent = sms_pfFilter.GetPFAddrL3(addresses); + // } + return L2PFsent; +} +bool L2CompositeWithWorkerPrefetcher::hasPFRequestsInBuffer() { + return largeBOP->hasPFRequestsInBuffer() || + smallBOP->hasPFRequestsInBuffer() || + cmc->hasPFRequestsInBuffer() || + cdp->hasPFRequestsInBuffer() || + despacitoStream->hasPFRequestsInBuffer(); + } } // namespace prefetch } // namespace gem5 diff --git a/src/mem/cache/prefetch/l2_composite_with_worker.hh b/src/mem/cache/prefetch/l2_composite_with_worker.hh index e6e450b7c5..0b6cf980ac 100644 --- a/src/mem/cache/prefetch/l2_composite_with_worker.hh +++ b/src/mem/cache/prefetch/l2_composite_with_worker.hh @@ -61,6 +61,13 @@ class L2CompositeWithWorkerPrefetcher : public CompositeWithWorkerPrefetcher const bool enableDespacitoStream; bool offloadLowAccuracy = true; + protected: + void InsertPFRequestToBuffer(const AddrPriority &addr_prio) override{ + panic("SMS:InsertPFRequestToBuffer not implemented"); + }; + public: + bool GetPFRequestsFromBuffer(std::vector &addresses) override; + bool hasPFRequestsInBuffer() override; }; } // namespace prefetch diff --git a/src/mem/cache/prefetch/opt.cc b/src/mem/cache/prefetch/opt.cc index d980158595..952d0bffbb 100644 --- a/src/mem/cache/prefetch/opt.cc +++ b/src/mem/cache/prefetch/opt.cc @@ -218,12 +218,24 @@ bool OptPrefetcher::sendPFWithFilter(const PrefetchInfo &pfi, Addr addr, std::vector 
&addresses, int prio, PrefetchSourceType src, int ahead_level) { + //buffered prefetch + AddrPriority pf_req = AddrPriority(addr, prio, src, pfi.trigger_info); + if (ahead_level > 1) { + assert(ahead_level == 2 || ahead_level == 3); + pf_req.pfahead_host = ahead_level; + pf_req.pfahead = true; + } else { + pf_req.pfahead = false; + } + InsertPFRequestToBuffer(pf_req); + // origin prefetch if (filter->contains(addr)) { DPRINTF(OptPrefetcher, "Skip recently prefetched: %lx\n", addr); return false; } else { DPRINTF(OptPrefetcher, "Send pf: %lx\n", addr); filter->insert(addr, 0); + addresses.push_back(AddrPriority(addr, prio, src)); if (ahead_level > 1) { assert(ahead_level == 2 || ahead_level == 3); addresses.back().pfahead_host = ahead_level; diff --git a/src/mem/cache/prefetch/prefetch_filter.cc b/src/mem/cache/prefetch/prefetch_filter.cc new file mode 100644 index 0000000000..2efb8f4699 --- /dev/null +++ b/src/mem/cache/prefetch/prefetch_filter.cc @@ -0,0 +1,497 @@ +#include "mem/cache/prefetch/prefetch_filter.hh" +#include +#include +#include +#include +#include + +#include "base/stats/group.hh" +#include "debug/HWPrefetch.hh" +#include "mem/cache/prefetch/associative_set_impl.hh" + +namespace gem5 { +namespace prefetch { + +PrefetchFilter::Stats::Stats(statistics::Group *parent, const std::string &name) + : statistics::Group(parent, name.c_str()), + ADD_STAT(insertCount, statistics::units::Count::get(), "PrefetchFilter insert count"), + ADD_STAT(queryHitCount, statistics::units::Count::get(), "PrefetchFilter query hit count"), + ADD_STAT(prefetchIssued, statistics::units::Count::get(), "Prefetches issued by PrefetchFilter"), + ADD_STAT(replacementCount, statistics::units::Count::get(), "PrefetchFilter replacement count"), + ADD_STAT(l1Calls, statistics::units::Count::get(), "GetPFAddrL1 calls"), + ADD_STAT(l1Issued, statistics::units::Count::get(), "GetPFAddrL1 issued"), + ADD_STAT(l2Calls, statistics::units::Count::get(), "GetPFAddrL2 calls"), + 
ADD_STAT(l2Issued, statistics::units::Count::get(), "GetPFAddrL2 issued"), + ADD_STAT(l3Calls, statistics::units::Count::get(), "GetPFAddrL3 calls"), + ADD_STAT(l3Issued, statistics::units::Count::get(), "GetPFAddrL3 issued"), + ADD_STAT(hashcollisionCount, statistics::units::Count::get(), "PrefetchFilter hash collision count") +{ + +} + +PrefetchFilter::PrefetchFilter(gem5::BaseIndexingPolicy *idx_policy, + gem5::replacement_policy::Base *rpl_policy, + unsigned entries, unsigned region_size, + unsigned blk_size, statistics::Group *parent, + unsigned vaddr_hash_width, + PrefetchSourceType pf_source_type, + const std::string &name) + : table(entries, entries, idx_policy,rpl_policy, Entry()), + regionSize(region_size), + blkSize(blk_size), + regionBlks(region_size / blk_size), + rrIndex(0), + REGION_ADDR_RAW_WIDTH(6),//align with rtl + vaddrHashWidth(vaddr_hash_width), + stats(parent, name), + pfSourceType(pf_source_type), + table_name(name) +{ +} + +PrefetchFilter::~PrefetchFilter() = default; + +// PrefetchFilter::Entry* +// PrefetchFilter::findByVaddr(Addr vaddr, bool is_secure) +// { +// Addr region = vaddr / regionSize; +// Entry *e = table.findEntry(region, is_secure); +// if (e) { +// stats.queryHitCount++; +// } +// return e; +// } + +// PrefetchFilter::Entry* +// PrefetchFilter::findByRegion(Addr region, bool is_secure) +// { +// Entry *e = table.findEntry(region, is_secure); +// if (e) { +// stats.queryHitCount++; +// } +// return e; +// } + +// PrefetchFilter::Entry* +// PrefetchFilter::allocateForVaddr(Addr vaddr, bool is_secure, Addr region_addr) +// { +// Addr region = vaddr / regionSize; +// Entry *victim = table.findVictim(region); + +// victim->region_addr = region_addr ? 
region_addr : region; +// victim->region_bits = 0; +// victim->filter_bits = 0; +// victim->alias_bits = aliasFromVaddr(vaddr); +// victim->paddr_valid = (region_addr != 0); +// victim->decr_mode = false; +// victim->_setSecure(is_secure); + +// table.insertEntry(region, is_secure, victim); +// return victim; +// } + +bool +PrefetchFilter::GetPFAddrL1(std::vector &addresses) +{ + stats.l1Calls++; + auto it_begin = table.begin(); + auto it_end = table.end(); + const auto n = std::distance(it_begin, it_end); + DPRINTF(HWPrefetch, "GetPFAddrL1 called. table size: %lu\n", static_cast(n)); + if (n == 0) + return false; + + uint64_t mask = (regionBlks >= 64) ? ~uint64_t(0) : ((uint64_t(1) << regionBlks) - 1); + + for (unsigned i = 0; i < static_cast(n); ++i) { + unsigned idx = (rrIndex + i) % static_cast(n); + auto it = it_begin; + std::advance(it, idx); + Entry *e = &(*it); + + if (!e->isValid()) + continue; + if (!e->paddr_valid) + continue; + if (e->PFlevel != 1) + continue; + uint64_t pending = e->region_bits & (~e->filter_bits) & mask; + if (!pending) + continue; + + unsigned region_offset = 0; + if (e->decr_mode) { + unsigned lz = __builtin_clzll(pending); + unsigned msb = 63 - lz; + region_offset = msb; + } else { + region_offset = __builtin_ctzll(pending); + } + + Addr region_num = e->region_addr; + // Use bit operations to compute: region_num * regionSize + region_offset * blkSize + // Assume regionSize and blkSize are powers of two; compute shift amounts. 
+ unsigned rs_shift = __builtin_ctz(regionSize); + unsigned bs_shift = __builtin_ctz(blkSize); + Addr pf_addr = (region_num << rs_shift) + (Addr(region_offset) << bs_shift); + + TriggerInfo trigger_info; + bool has_trigger = false; + if (region_offset < e->bitTriggers.size()) { + auto &slot = e->bitTriggers[region_offset]; + if (slot) { + trigger_info = *slot; + slot.reset(); + has_trigger = true; + } + } + markBlockSent(e, region_offset); + + stats.prefetchIssued++; + + rrIndex = (idx + 1) % static_cast(n); + + // construct AddrPriority. Use a simple priority scheme: closer blocks get higher priority. + int prio = static_cast(regionBlks) - static_cast(region_offset); + if (has_trigger) { + addresses.emplace_back(AddrPriority(pf_addr, prio, + trigger_info.pfSourceType == PrefetchSourceType::PF_NONE ? pfSourceType : trigger_info.pfSourceType, trigger_info)); + } else { + addresses.emplace_back(AddrPriority(pf_addr, prio, pfSourceType)); + } + stats.l1Issued++; + DPRINTF(HWPrefetch, "GetPFAddrL1 issued addr=%#lx prio=%d trigger=%d\n", + pf_addr, prio, has_trigger); + return true; + } + + return false; +} +bool +PrefetchFilter::GetPFAddrL2(std::vector &addresses) +{ + stats.l2Calls++; + auto it_begin = table.begin(); + auto it_end = table.end(); + const auto n = std::distance(it_begin, it_end); + DPRINTF(HWPrefetch, "GetPFAddrL2 called. table size: %lu\n", static_cast(n)); + if (n == 0) + return false; + + uint64_t mask = (regionBlks >= 64) ? 
~uint64_t(0) : ((uint64_t(1) << regionBlks) - 1); + + for (unsigned i = 0; i < static_cast(n); ++i) { + unsigned idx = (rrIndex + i) % static_cast(n); + auto it = it_begin; + std::advance(it, idx); + Entry *e = &(*it); + + if (!e->isValid()) + continue; + if (!e->paddr_valid) + continue; + if (e->PFlevel != 2) + continue; + uint64_t pending = e->region_bits & (~e->filter_bits) & mask; + if (!pending) + continue; + + unsigned region_offset = 0; + if (e->decr_mode) { + unsigned lz = __builtin_clzll(pending); + unsigned msb = 63 - lz; + region_offset = msb; + } else { + region_offset = __builtin_ctzll(pending); + } + + Addr region_num = e->region_addr; + // Use bit operations to compute: region_num * regionSize + region_offset * blkSize + // Assume regionSize and blkSize are powers of two; compute shift amounts. + unsigned rs_shift = __builtin_ctz(regionSize); + unsigned bs_shift = __builtin_ctz(blkSize); + Addr pf_addr = (region_num << rs_shift) + (Addr(region_offset) << bs_shift); + + TriggerInfo trigger_info; + bool has_trigger = false; + if (region_offset < e->bitTriggers.size()) { + auto &slot = e->bitTriggers[region_offset]; + if (slot) { + trigger_info = *slot; + slot.reset(); + has_trigger = true; + } + } + markBlockSent(e, region_offset); + + stats.prefetchIssued++; + + rrIndex = (idx + 1) % static_cast(n); + + // construct AddrPriority. Use a simple priority scheme: closer blocks get higher priority. + int prio = static_cast(regionBlks) - static_cast(region_offset); + if (has_trigger) { + addresses.emplace_back(AddrPriority(pf_addr, prio, + trigger_info.pfSourceType == PrefetchSourceType::PF_NONE ? 
pfSourceType : trigger_info.pfSourceType, trigger_info)); + } else { + addresses.emplace_back(AddrPriority(pf_addr, prio, pfSourceType)); + } + addresses.back().pfahead_host = 2; + addresses.back().pfahead = true; + stats.l2Issued++; + DPRINTF(HWPrefetch, "GetPFAddrL2 issued addr=%#lx prio=%d trigger=%d\n", + pf_addr, prio, has_trigger); + return true; + } + + return false; +} +bool +PrefetchFilter::GetPFAddrL3(std::vector &addresses) +{ + stats.l3Calls++; + auto it_begin = table.begin(); + auto it_end = table.end(); + const auto n = std::distance(it_begin, it_end); + DPRINTF(HWPrefetch, "GetPFAddrL3 called. table size: %lu\n", static_cast(n)); + if (n == 0) + return false; + + uint64_t mask = (regionBlks >= 64) ? ~uint64_t(0) : ((uint64_t(1) << regionBlks) - 1); + + for (unsigned i = 0; i < static_cast(n); ++i) { + unsigned idx = (rrIndex + i) % static_cast(n); + auto it = it_begin; + std::advance(it, idx); + Entry *e = &(*it); + + if (!e->isValid()) + continue; + if (!e->paddr_valid) + continue; + if (e->PFlevel != 3) + continue; + uint64_t pending = e->region_bits & (~e->filter_bits) & mask; + if (!pending) + continue; + + unsigned region_offset = 0; + if (e->decr_mode) { + unsigned lz = __builtin_clzll(pending); + unsigned msb = 63 - lz; + region_offset = msb; + } else { + region_offset = __builtin_ctzll(pending); + } + + Addr region_num = e->region_addr; + // Use bit operations to compute: region_num * regionSize + region_offset * blkSize + // Assume regionSize and blkSize are powers of two; compute shift amounts. 
+ unsigned rs_shift = __builtin_ctz(regionSize); + unsigned bs_shift = __builtin_ctz(blkSize); + Addr pf_addr = (region_num << rs_shift) + (Addr(region_offset) << bs_shift); + + TriggerInfo trigger_info; + bool has_trigger = false; + if (region_offset < e->bitTriggers.size()) { + auto &slot = e->bitTriggers[region_offset]; + if (slot) { + trigger_info = *slot; + slot.reset(); + has_trigger = true; + } + } + markBlockSent(e, region_offset); + + stats.prefetchIssued++; + + rrIndex = (idx + 1) % static_cast(n); + + // construct AddrPriority. Use a simple priority scheme: closer blocks get higher priority. + int prio = static_cast(regionBlks) - static_cast(region_offset); + if (has_trigger) { + addresses.emplace_back(AddrPriority(pf_addr, prio, + trigger_info.pfSourceType == PrefetchSourceType::PF_NONE ? pfSourceType : trigger_info.pfSourceType, trigger_info)); + } else { + addresses.emplace_back(AddrPriority(pf_addr, prio, pfSourceType)); + } + addresses.back().pfahead_host = 3; + addresses.back().pfahead = true; + stats.l3Issued++; + DPRINTF(HWPrefetch, "GetPFAddrL3 issued addr=%#lx prio=%d trigger=%d\n", + pf_addr, prio, has_trigger); + return true; + } + + return false; +} + +void +PrefetchFilter::markBlockSent(PrefetchFilter::Entry *e, unsigned blk_idx) +{ + if (!e || blk_idx >= regionBlks) + return; + e->filter_bits |= (uint64_t(1) << blk_idx); + table.accessEntry(e); +} + +void +PrefetchFilter::addRegionBits(PrefetchFilter::Entry *e, uint64_t bits) +{ + if (!e) + return; + e->region_bits |= bits; + table.accessEntry(e); +} + +void +PrefetchFilter::ensureTriggerStorage(PrefetchFilter::Entry &e) +{ + if (e.bitTriggers.size() != regionBlks) { + e.bitTriggers.clear(); + e.bitTriggers.resize(regionBlks); + } +} + +void +PrefetchFilter::storeTriggersForBits(PrefetchFilter::Entry &e, uint64_t bits, + const TriggerInfo *trigger) +{ + if (!trigger || regionBlks == 0 || bits == 0) + return; + + ensureTriggerStorage(e); + + uint64_t remaining = bits; + unsigned limit = 
(regionBlks > 64) ? 64 : regionBlks; + for (unsigned idx = 0; idx < limit && remaining; ++idx) { + if (remaining & uint64_t(1)) { + PacketPtr pkt = trigger->pkt; + e.bitTriggers[idx] = std::make_unique(*trigger); + } + remaining >>= 1; + } +} + +PrefetchFilter::Entry* +PrefetchFilter::Insert(Addr region_addr, uint64_t region_bits, uint8_t alias_bits, + bool paddr_valid, bool decr_mode, + bool is_secure, uint64_t PFlevel, + const TriggerInfo *trigger) +{ + stats.insertCount++; + Addr tag = regionHashTag(region_addr); + Entry *e = table.findEntry(tag, is_secure); + DPRINTF(HWPrefetch, "Insert called: region=%#lx tag=%#lx bits=%#lx level=%lu,name=%s\n", + region_addr, tag, region_bits, PFlevel, table_name.c_str()); + if (e) { + if (e->region_addr != region_addr) { + DPRINTF(HWPrefetch, "Warning: Insert called with existing entry but different region_addr. existing=%#lx new=%#lx\n", + e->region_addr, region_addr); + stats.hashcollisionCount++; + } + storeTriggersForBits(*e, region_bits, trigger); + e->region_bits |= region_bits; + table.accessEntry(e); + stats.queryHitCount++; + DPRINTF(HWPrefetch, "Insert hit: region=%#lx tag=%#lx bits=%#lx level=%lu\n", + region_addr, tag, region_bits, PFlevel); + //print all entry status + for (const auto &entry : table) { + DPRINTF(HWPrefetch, " Entry: region=%#lx tag=%#lx bits=%#lx filter=%#lx level=%lu valid=%d\n", + entry.region_addr, entry.getTag(), entry.region_bits, + entry.filter_bits, entry.PFlevel, entry.isValid()); + } + return e; + } + stats.replacementCount++; + Entry *victim = table.findVictim(tag); + victim->region_addr = region_addr; + victim->region_bits = region_bits; + victim->filter_bits = 0; + victim->alias_bits = alias_bits; + victim->paddr_valid = true; + victim->decr_mode = decr_mode; + victim->_setSecure(is_secure); + victim->PFlevel = PFlevel; + ensureTriggerStorage(*victim); + for (auto &slot : victim->bitTriggers) { + slot.reset(); + } + storeTriggersForBits(*victim, region_bits, trigger); + + 
table.insertEntry(tag, is_secure, victim); + DPRINTF(HWPrefetch, "Insert miss: region=%#lx tag=%#lx bits=%#lx level=%lu\n", + region_addr, tag, region_bits, PFlevel); + //print all entry status + for (const auto &entry : table) { + DPRINTF(HWPrefetch, " Entry: region=%#lx tag=%#lx bits=%#lx filter=%#lx level=%lu valid=%d\n", + entry.region_addr, entry.getTag(), entry.region_bits, + entry.filter_bits, entry.PFlevel, entry.isValid()); + + } + return victim; +} + +// uint64_t +// PrefetchFilter::pendingBlocks(PrefetchFilter::Entry *e) const +// { +// if (!e) +// return 0; +// return e->region_bits & static_cast(~e->filter_bits); +// } + // region-hash-tag implementation per chisel spec +Addr +PrefetchFilter::regionHashTag(Addr vaddr) const +{ + Addr low_mask = ((Addr(1) << REGION_ADDR_RAW_WIDTH) - 1); + Addr low = vaddr & low_mask; + + unsigned high_low = REGION_ADDR_RAW_WIDTH; + unsigned high_bits = 3 * vaddrHashWidth; + Addr high = (vaddr >> high_low) & ((Addr(1) << high_bits) - 1); + + Addr seg0 = high & ((Addr(1) << vaddrHashWidth) - 1); + Addr seg1 = (high >> vaddrHashWidth) & ((Addr(1) << vaddrHashWidth) - 1); + Addr seg2 = (high >> (2 * vaddrHashWidth)) & ((Addr(1) << vaddrHashWidth) - 1); + Addr high_hash = seg0 ^ seg1 ^ seg2; + + Addr tag = (high_hash << REGION_ADDR_RAW_WIDTH) | low; + return tag; +} +bool +PrefetchFilter::hasPFRequestsInBuffer() +{ + auto it_begin = table.begin(); + auto it_end = table.end(); + const auto n = std::distance(it_begin, it_end); + if (n == 0) + return false; + DPRINTF(HWPrefetch, "hasPFRequestsInBuffer called. table size: %lu,name=%s\n", static_cast(n), table_name.c_str()); + //print all entry status + for (const auto &entry : table) { + DPRINTF(HWPrefetch, " Entry: region=%#lx tag=%#lx bits=%#lx filter=%#lx level=%lu valid=%d\n", + entry.region_addr, entry.getTag(), entry.region_bits, + entry.filter_bits, entry.PFlevel, entry.isValid()); + } + uint64_t mask = (regionBlks >= 64) ? 
~uint64_t(0) : ((uint64_t(1) << regionBlks) - 1); + + for (unsigned i = 0; i < static_cast(n); ++i) { + unsigned idx = (rrIndex + i) % static_cast(n); + auto it = it_begin; + std::advance(it, idx); + Entry *e = &(*it); + + if (!e->isValid()) + continue; + if (!e->paddr_valid) + continue; + uint64_t pending = e->region_bits & (~e->filter_bits) & mask; + if (pending) + return true; + } + + return false; +} +} // namespace prefetch +} // namespace gem5 diff --git a/src/mem/cache/prefetch/prefetch_filter.hh b/src/mem/cache/prefetch/prefetch_filter.hh new file mode 100644 index 0000000000..d69d021c52 --- /dev/null +++ b/src/mem/cache/prefetch/prefetch_filter.hh @@ -0,0 +1,186 @@ +// PrefetchFilter header moved out of sms.hh + +#ifndef GEM5_PREFETCH_FILTER_HH +#define GEM5_PREFETCH_FILTER_HH + +#include +#include +#include +#include + +#include "base/statistics.hh" +#include "base/types.hh" +#include "mem/cache/prefetch/associative_set.hh" +#include "mem/cache/tags/tagged_entry.hh" +#include "mem/cache/prefetch/queued.hh" // for AddrPriority and PrefetchSourceType + +namespace gem5 { +namespace prefetch { +class BaseIndexingPolicy; +class Base; +namespace replacement_policy { class Base; } +using AddrPriority = gem5::prefetch::Queued::AddrPriority; +class PrefetchFilter +{ + public: + using TriggerInfo = Base::PFtriggerInfo; + + struct Entry : public TaggedEntry { + Addr region_addr; // region number (Vaddr[38:10] or Paddr[35:10] when paddr_valid) + uint64_t region_bits; // which blocks in region should be prefetched (runtime width) + uint64_t filter_bits; // which prefetch requests have been issued + uint8_t alias_bits; // Vaddr[13:12] for VIPT aliasing,not needed + bool paddr_valid; // true if region_addr is physical + bool decr_mode; // 1 if decrementing prefetch mode + uint64_t PFlevel; // prefetch level for this region, L1/L2/L3 + std::vector> bitTriggers; + + Entry() + : TaggedEntry(), region_addr(0), region_bits(0), filter_bits(0), alias_bits(0), + 
paddr_valid(false), decr_mode(false), PFlevel(0) {} + + Entry(const Entry &other) + : TaggedEntry(other), + region_addr(other.region_addr), + region_bits(other.region_bits), + filter_bits(other.filter_bits), + alias_bits(other.alias_bits), + paddr_valid(other.paddr_valid), + decr_mode(other.decr_mode), + PFlevel(other.PFlevel) + { + copyTriggers(other); + } + + Entry& operator=(const Entry &other) + { + if (this != &other) { + TaggedEntry::operator=(other); + region_addr = other.region_addr; + region_bits = other.region_bits; + filter_bits = other.filter_bits; + alias_bits = other.alias_bits; + paddr_valid = other.paddr_valid; + decr_mode = other.decr_mode; + PFlevel = other.PFlevel; + copyTriggers(other); + } + return *this; + } + + Entry(Entry &&) noexcept = default; + Entry& operator=(Entry &&) noexcept = default; + ~Entry() = default; + + void _setSecure(bool is_secure) { + if (is_secure) TaggedEntry::setSecure(); + } + + private: + void copyTriggers(const Entry &other) + { + bitTriggers.clear(); + bitTriggers.reserve(other.bitTriggers.size()); + for (const auto &src : other.bitTriggers) { + if (src) { + PacketPtr pkt = src->pkt; + bitTriggers.emplace_back( + std::make_unique(pkt, *(src->pfi_old))); + } else { + bitTriggers.emplace_back(nullptr); + } + } + } + }; + + static constexpr unsigned DEFAULT_REGION_SIZE = 1024; // 1KB + + PrefetchFilter(gem5::BaseIndexingPolicy *idx_policy, gem5::replacement_policy::Base *rpl_policy, + unsigned entries = 16, unsigned region_size = DEFAULT_REGION_SIZE, + unsigned blk_size = 64, statistics::Group *parent = nullptr, + unsigned vaddr_hash_width = 2, PrefetchSourceType pf_source_type = PrefetchSourceType::PF_NONE, + const std::string &name = "prefetch_filter"); + ~PrefetchFilter(); + + // Lookup entry by virtual address (uses VA->region conversion and TaggedEntry tag) + Entry* findByVaddr(Addr vaddr, bool is_secure = false); + + // Lookup by region number (region = vaddr / REGION_SIZE) + Entry* findByRegion(Addr region, 
bool is_secure = false); + + // Allocate or replace an entry for this vaddr/region. Returns the entry pointer. + Entry* allocateForVaddr(Addr vaddr, bool is_secure = false, Addr region_addr = 0); + + // Mark that a specific block index in region has been issued as prefetch. + void markBlockSent(Entry *e, unsigned blk_idx); + + // Add region_bits (OR) to the entry, marking predicted blocks for prefetch + void addRegionBits(Entry *e, uint64_t bits); + + // Insert or update an entry for a given region. If an entry for `region` + // exists, OR `region_bits` into the existing entry and return it. If it + // doesn't, allocate/overwrite a victim entry and initialize its fields + // (filter_bits defaults to 0) and insert it into the table. Returns the + // updated or newly inserted Entry pointer. + Entry* Insert(Addr region_addr = 0, uint64_t region_bits = 0, uint8_t alias_bits = 0, + bool paddr_valid = false, bool decr_mode = false, + bool is_secure = false, uint64_t PFlevel = 1, + const TriggerInfo *trigger = nullptr); + // Get blocks still pending prefetch (region_bits & ~filter_bits) + uint64_t pendingBlocks(Entry *e) const; + + // Compute alias bits from virtual address + static uint8_t aliasFromVaddr(Addr vaddr) { return (vaddr >> 12) & 0x3; } + + private: + AssociativeSet table; + unsigned regionSize; + unsigned blkSize; + unsigned regionBlks; + unsigned rrIndex{0}; + const unsigned REGION_ADDR_RAW_WIDTH; + const unsigned vaddrHashWidth; // width for vaddr hash (per chisel spec) + + void ensureTriggerStorage(Entry &e); + void storeTriggersForBits(Entry &e, uint64_t bits, const TriggerInfo *trigger); + + // Compute region-hash tag as described by chisel: + // low = region_tag[BLK_ADDR_RAW_WIDTH-1:0] + // high = region_tag[BLK_ADDR_RAW_WIDTH-1+3*VADDR_HASH_WIDTH : BLK_ADDR_RAW_WIDTH] + // high_hash = xor of 3 segments of VADDR_HASH_WIDTH bits from 'high' + // tag = concat(high_hash, low) + Addr regionHashTag(Addr vaddr) const; + + public: + // Statistics for the 
PrefetchFilter. Parent should be provided by the + // owner (e.g. XSCompositePrefetcher::stats) so counters are exposed in + // gem5's statistics framework. + struct Stats : public statistics::Group { + Stats(statistics::Group *parent, const std::string &name); + statistics::Scalar insertCount; + statistics::Scalar queryHitCount; + statistics::Scalar prefetchIssued; + statistics::Scalar replacementCount; + statistics::Scalar l1Calls; + statistics::Scalar l1Issued; + statistics::Scalar l2Calls; + statistics::Scalar l2Issued; + statistics::Scalar l3Calls; + statistics::Scalar l3Issued; + statistics::Scalar hashcollisionCount; + } stats; + PrefetchSourceType pfSourceType; + const std::string table_name; + public: + // Select next prefetch address from the filter table using a round-robin + // arbiter. Returns true and pushes an AddrPriority into `addresses` if a candidate is found. + bool GetPFAddrL1(std::vector &addresses); + bool GetPFAddrL2(std::vector &addresses); + bool GetPFAddrL3(std::vector &addresses); + bool hasPFRequestsInBuffer(); +}; + +} // namespace prefetch +} // namespace gem5 + +#endif // GEM5_PREFETCH_FILTER_HH diff --git a/src/mem/cache/prefetch/queued.cc b/src/mem/cache/prefetch/queued.cc index 7e7f1b6b2f..5fbeb76b68 100644 --- a/src/mem/cache/prefetch/queued.cc +++ b/src/mem/cache/prefetch/queued.cc @@ -38,10 +38,12 @@ #include "mem/cache/prefetch/queued.hh" #include +#include #include "arch/generic/tlb.hh" #include "base/logging.hh" #include "base/trace.hh" +#include "cmc.hh" #include "debug/HWPrefetch.hh" #include "debug/HWPrefetchOther.hh" #include "debug/HWPrefetchQueue.hh" @@ -129,7 +131,14 @@ Queued::Queued(const QueuedPrefetcherParams &p) tlbReqEvent( [this]{ processMissingTranslations(queueSize); }, name()), - statsQueued(this) + statsQueued(this), + usePFBuffer(p.use_pf_buffer), + PFRequestBuffer(), + max_pf_buffer_size(p.max_pf_buffer_size), + PFReqSendEvent( + [this]{ PFSendEventWrapper(); }, + name()) + { } @@ -243,10 +252,18 @@ 
Queued::notify(const PacketPtr &pkt, const PrefetchInfo &pfi) // Calculate prefetches given this access std::vector addresses; // if (!pkt->coalescingMSHR) { // hit to Other cpu access + pfi.setTriggerInfo(pkt); calculatePrefetch(pfi, addresses, pfi.isCacheMiss() && (late_in_mshr || late_in_pfq), pf_source, pkt->coalescingMSHR); // } - + if (usePFBuffer) { + //PFs supposed to be stored in buffer,just trigger PF send event + if (!PFReqSendEvent.scheduled()) { + //even if this cycle has trained,we assume it take 1 cycle to generate PFs + schedule(PFReqSendEvent, nextCycle()); + } + return; + } // Get the maximu number of prefetches that we are allowed to generate size_t max_pfs = getMaxPermittedPrefetches(addresses.size()); @@ -283,7 +300,56 @@ Queued::notify(const PacketPtr &pkt, const PrefetchInfo &pfi) } } } +void +Queued::PFSendEventWrapper() +{ + std::vector addresses; + GetPFRequestsFromBuffer(addresses); + + // there may be more than 1 req in addresses because we are trying to allow max 1 PF to every cache level + // assert(addresses.size()==1); + // Get the maximu number of prefetches that we are allowed to generate + size_t max_pfs = getMaxPermittedPrefetches(addresses.size()); + + // Queue up generated prefetches + size_t num_pfs = 0; + for (AddrPriority& addr_prio : addresses) { + + PacketPtr pkt = addr_prio.pf_trigger_info.pkt; + PrefetchInfo pfi = PrefetchInfo(*addr_prio.pf_trigger_info.pfi_old); + //override address's prio to 1 + addr_prio.priority = 1; + // Block align prefetch address + addr_prio.addr = blockAddress(addr_prio.addr); + + if (!samePage(addr_prio.addr, pfi.getAddr())) { + statsQueued.pfSpanPage += 1; + + if (hasBeenPrefetched(pkt->getAddr(), pkt->isSecure())) { + statsQueued.pfUsefulSpanPage += 1; + } + } + bool can_cross_page = (tlb != nullptr); + if (can_cross_page || samePage(addr_prio.addr, pfi.getAddr())) { + PrefetchInfo new_pfi(pfi, addr_prio.addr); + new_pfi.setXsMetadata(Request::XsMetadata(addr_prio.pfSource,addr_prio.depth)); 
+ statsQueued.pfIdentified++; + DPRINTF(HWPrefetch, "Found a pf candidate addr: %#x, " + "inserting into prefetch queue.\n", new_pfi.getAddr()); + insert(pkt, new_pfi, addr_prio); + num_pfs += 1; + if (num_pfs == max_pfs) { + break; + } + } else { + DPRINTF(HWPrefetch, "Ignoring page crossing prefetch.\n"); + } + } + if (hasPFRequestsInBuffer() && !PFReqSendEvent.scheduled()) { + schedule(PFReqSendEvent, nextCycle()); // schedule next PF send event + } +} bool Queued::hasPendingPacket() { @@ -332,8 +398,13 @@ Queued::QueuedStats::QueuedStats(statistics::Group *parent) ADD_STAT(pfSpanPage, statistics::units::Count::get(), "number of prefetches that crossed the page"), ADD_STAT(pfUsefulSpanPage, statistics::units::Count::get(), - "number of prefetches that is useful and crossed the page") -{ + "number of prefetches that is useful and crossed the page"), + ADD_STAT(pfRemovedFull_srcs, statistics::units::Count::get(), + "src distribute of Removedfull prefetch") +{ using namespace statistics; + pfRemovedFull_srcs + .init(NUM_PF_SOURCES) + .flags(total); } @@ -384,7 +455,7 @@ Queued::translationComplete(DeferredPacket *dp, bool failed) it->translationRequest->getPaddr()); Addr target_paddr = it->translationRequest->getPaddr(); // check if this prefetch is already redundant - if (cacheSnoop && (inCache(target_paddr, it->pfInfo.isSecure()) || + if (cacheSnoop && queueFilter && (inCache(target_paddr, it->pfInfo.isSecure()) || inMissQueue(target_paddr, it->pfInfo.isSecure()))) { statsQueued.pfInCache++; DPRINTF(HWPrefetch, "Dropping redundant in " @@ -554,7 +625,7 @@ Queued::insert(const PacketPtr &pkt, PrefetchInfo &new_pfi, const AddrPriority & return; } } - if (has_target_pa && cacheSnoop && + if (has_target_pa && cacheSnoop && queueFilter && (inCache(target_paddr, new_pfi.isSecure()) || inMissQueue(target_paddr, new_pfi.isSecure()))) { statsQueued.pfInCache++; @@ -650,6 +721,8 @@ Queued::addToQueue(std::list &queue, } DPRINTF(HWPrefetch, "%s full (sz=%lu), removing 
lowest priority oldest packet, addr: %#x\n", queue_name, queue.size(), it->pfInfo.getAddr()); + statsQueued.pfRemovedFull_srcs[it->pfInfo.getXsMetadata().prefetchSource]++; + if (&queue == &pfq || !it->ongoingTranslation){ delete it->pkt; queue.erase(it); diff --git a/src/mem/cache/prefetch/queued.hh b/src/mem/cache/prefetch/queued.hh index 13d9885acb..b8270f2875 100644 --- a/src/mem/cache/prefetch/queued.hh +++ b/src/mem/cache/prefetch/queued.hh @@ -57,6 +57,8 @@ GEM5_DEPRECATED_NAMESPACE(Prefetcher, prefetch); namespace prefetch { +using PFTriggerInfo = Base::PFtriggerInfo; + class Queued : public Base { public: @@ -70,6 +72,7 @@ class Queued : public Base bool pfahead = false; int depth=0; PrefetchSourceType pfSource; + PFTriggerInfo pf_trigger_info{}; PrefetchCmd(Addr a, int32_t p) : addr(a), priority(p), isVA(true), isBOP(false) { panic("PrefetchCmd: no source specified"); @@ -78,6 +81,10 @@ class Queued : public Base : addr(a), priority(p), isVA(true), isBOP(false), pfSource(src) { } + PrefetchCmd(Addr a, int32_t p, PrefetchSourceType src, PFTriggerInfo pf_info) + : addr(a), priority(p), isVA(true), isBOP(false), pfSource(src), pf_trigger_info(pf_info) + { + } PrefetchCmd(Addr a, int32_t p, PrefetchSourceType src, bool va, bool bop) : addr(a), priority(p), isVA(va), isBOP(bop), pfSource(src) { @@ -219,6 +226,7 @@ class Queued : public Base statistics::Scalar pfRemovedFull; statistics::Scalar pfSpanPage; statistics::Scalar pfUsefulSpanPage; + statistics::Vector pfRemovedFull_srcs; } statsQueued; public: @@ -306,6 +314,36 @@ class Queued : public Base void pfHitNotify(float accuracy, PrefetchSourceType pf_source, const PacketPtr &pkt) override { } void offloadToDownStream() override; + protected: + const bool usePFBuffer{false}; + std::list PFRequestBuffer; + const int max_pf_buffer_size{8}; + //here we implement a buffer that drop the pf requests when the buffer is full + virtual void InsertPFRequestToBuffer(const AddrPriority &addr_prio) { + if 
(PFRequestBuffer.size() < max_pf_buffer_size) { + PFRequestBuffer.push_back(addr_prio); + }else{ + PFRequestBuffer.pop_front(); + PFRequestBuffer.push_back(addr_prio); + } + }; + /** Event to handle the delay queue processing */ + void PFSendEventWrapper(); + EventFunctionWrapper PFReqSendEvent; + + public: + virtual bool hasPFRequestsInBuffer() { + return !PFRequestBuffer.empty(); + } + virtual bool GetPFRequestsFromBuffer(std::vector &addresses) { + if (PFRequestBuffer.empty()) { + return false; + } + AddrPriority addr_prio = PFRequestBuffer.front(); + PFRequestBuffer.pop_front(); + addresses.push_back(addr_prio); + return true; + }; }; } // namespace prefetch diff --git a/src/mem/cache/prefetch/signature_path.cc b/src/mem/cache/prefetch/signature_path.cc index 981c1d153c..2e189008fa 100644 --- a/src/mem/cache/prefetch/signature_path.cc +++ b/src/mem/cache/prefetch/signature_path.cc @@ -97,9 +97,10 @@ SignaturePath::PatternEntry::getStrideEntry(stride_t stride) } void -SignaturePath::addPrefetch(Addr ppn, stride_t last_block, stride_t delta, double path_confidence, - signature_t signature, bool is_secure, std::vector &addresses, - boost::compute::detail::lru_cache &filter) +SignaturePath::addPrefetch(const PrefetchInfo &pfi, Addr ppn, + stride_t last_block, stride_t delta, double path_confidence, + signature_t signature, bool is_secure, std::vector &addresses, + boost::compute::detail::lru_cache &filter) { stride_t block = last_block + delta; @@ -134,7 +135,7 @@ SignaturePath::addPrefetch(Addr ppn, stride_t last_block, stride_t delta, double new_addr += pf_block * (Addr)blkSize; DPRINTF(SPP, "Queuing prefetch to %#x, with stride of %#x(%u) blocks.\n", new_addr, delta, delta); - sendPFWithFilter(new_addr, addresses, 0, filter); + sendPFWithFilter(pfi, new_addr, addresses, 0, filter); } void @@ -301,7 +302,7 @@ SignaturePath::calculatePrefetch(const PrefetchInfo &pfi, std::vector init_addr_size; return sent; } void -SignaturePath::auxiliaryPrefetcher(Addr ppn, 
stride_t current_block, bool is_secure, +SignaturePath::auxiliaryPrefetcher(const PrefetchInfo &pfi, Addr ppn, stride_t current_block, bool is_secure, std::vector &addresses, boost::compute::detail::lru_cache &filter) { if (addresses.empty()) { // Enable the next line prefetcher if no prefetch candidates are found - addPrefetch(ppn, current_block, 1, 0.0 /* unused*/, 0 /* unused */, + addPrefetch(pfi, ppn, current_block, 1, 0.0 /* unused*/, 0 /* unused */, is_secure, addresses, filter); } } bool -SignaturePath::sendPFWithFilter(Addr addr, std::vector &addresses, int prio, +SignaturePath::sendPFWithFilter(const PrefetchInfo &pfi, Addr addr, std::vector &addresses, int prio, boost::compute::detail::lru_cache &filter) { + InsertPFRequestToBuffer(AddrPriority(addr, prio, PrefetchSourceType::SPP, pfi.trigger_info)); if (filter.contains(addr)) { DPRINTF(SPP, "Skip recently prefetched: %lx\n", addr); return false; diff --git a/src/mem/cache/prefetch/signature_path.hh b/src/mem/cache/prefetch/signature_path.hh index f905d8fa35..cb0c6a04af 100644 --- a/src/mem/cache/prefetch/signature_path.hh +++ b/src/mem/cache/prefetch/signature_path.hh @@ -179,7 +179,8 @@ class SignaturePath : public Queued * @param is_secure whether this page is inside the secure memory area * @param addresses addresses to prefetch will be added to this vector */ - void addPrefetch(Addr ppn, stride_t last_block, stride_t delta, double path_confidence, signature_t signature, + void addPrefetch(const PrefetchInfo &pfi, Addr ppn, stride_t last_block, stride_t delta, + double path_confidence, signature_t signature, bool is_secure, std::vector &addresses, boost::compute::detail::lru_cache &filter); @@ -265,7 +266,7 @@ class SignaturePath : public Queued * @param updated_filter_entries set of addresses containing these that * their filter has been updated, if this call updates a new entry */ - virtual void auxiliaryPrefetcher(Addr ppn, stride_t current_block, bool is_secure, + virtual void 
auxiliaryPrefetcher(const PrefetchInfo &pfi, Addr ppn, stride_t current_block, bool is_secure, std::vector &addresses, boost::compute::detail::lru_cache &filter); @@ -299,7 +300,7 @@ class SignaturePath : public Queued boost::compute::detail::lru_cache &filter, int32_t &best_block_offset); private: - bool sendPFWithFilter(Addr addr, std::vector &addresses, int prio, + bool sendPFWithFilter(const PrefetchInfo &pfi, Addr addr, std::vector &addresses, int prio, boost::compute::detail::lru_cache &filter); unsigned sPageBytes; diff --git a/src/mem/cache/prefetch/signature_path_v2.hh b/src/mem/cache/prefetch/signature_path_v2.hh index 7c5fbb7fcf..5d12a6d1cf 100644 --- a/src/mem/cache/prefetch/signature_path_v2.hh +++ b/src/mem/cache/prefetch/signature_path_v2.hh @@ -86,8 +86,9 @@ class SignaturePathV2 : public SignaturePath * In this version of the Signature Path Prefetcher, there is no auxiliary * prefetcher, so this function does not perform any actions. */ - void auxiliaryPrefetcher(Addr ppn, stride_t current_block, bool is_secure, std::vector &addresses, - boost::compute::detail::lru_cache &filter) override + void auxiliaryPrefetcher(const PrefetchInfo &pfi, Addr ppn, stride_t current_block, bool is_secure, + std::vector &addresses, + boost::compute::detail::lru_cache &filter) override {} virtual void handlePageCrossingLookahead(signature_t signature, diff --git a/src/mem/cache/prefetch/sms.cc b/src/mem/cache/prefetch/sms.cc index 359f88996c..092d0a9082 100644 --- a/src/mem/cache/prefetch/sms.cc +++ b/src/mem/cache/prefetch/sms.cc @@ -1,4 +1,7 @@ #include "mem/cache/prefetch/sms.hh" +#include +#include +#include #include "base/stats/group.hh" #include "debug/BOPOffsets.hh" @@ -10,10 +13,14 @@ namespace gem5 namespace prefetch { +// PrefetchFilter implementation moved to prefetch_filter.{hh,cc} + + XSCompositePrefetcher::XSCompositePrefetcher(const XSCompositePrefetcherParams &p) : Queued(p), regionSize(p.region_size), regionBlks(p.region_size / p.block_size), + 
enableTrainFilter(p.enable_train_filter), act(p.act_entries, p.act_entries, p.act_indexing_policy, p.act_replacement_policy, ACTEntry(SatCounter8(2, 1))), re_act(p.re_act_entries, p.re_act_entries, p.re_act_indexing_policy, @@ -26,6 +33,17 @@ XSCompositePrefetcher::XSCompositePrefetcher(const XSCompositePrefetcherParams & phtPFLevel(std::min(p.pht_pf_level, (int) 3)), stats(this), pfBlockLRUFilter(pfFilterSize), + sms_pfFilter(p.sms_filter_indexing_policy, p.sms_filter_replacement_policy, p.sms_filter_entries, + p.region_size, p.block_size, this, p.vaddr_hash_width, + PrefetchSourceType::SPht, "sms_pfFilter"), + stridestream_pfFilter_l1(p.stridestream_L1_filter_indexing_policy, p.stridestream_L1_filter_replacement_policy, + p.stridestream_L1_filter_entries, p.region_size, p.block_size, this, + p.vaddr_hash_width, PrefetchSourceType::SStream, + "stridestream_pfFilter_l1"), + stridestream_pfFilter_l2l3(p.stridestream_L2L3_filter_indexing_policy, p.stridestream_L2L3_filter_replacement_policy, + p.stridestream_L2L3_filter_entries, p.region_size, p.block_size, this, + p.vaddr_hash_width, PrefetchSourceType::SStream, + "stridestream_pfFilter_l2l3"), pfPageLRUFilter(pfPageFilterSize), pfPageLRUFilterL2(pfPageFilterSize), pfPageLRUFilterL3(pfPageFilterSize), @@ -50,7 +68,11 @@ XSCompositePrefetcher::XSCompositePrefetcher(const XSCompositePrefetcherParams & enableOpt(p.enable_opt), enableXsstream(p.enable_xsstream), phtEarlyUpdate(p.pht_early_update), - neighborPhtUpdate(p.neighbor_pht_update) + neighborPhtUpdate(p.neighbor_pht_update), + phtSentPrefetch(), + phtReqSendEvent([this]{ phtSendEventWrapper(); }, + name()), + BOPPFlevel(p.bop_pf_level) { assert(largeBOP); assert(smallBOP); @@ -80,6 +102,20 @@ XSCompositePrefetcher::XSCompositePrefetcher(const XSCompositePrefetcherParams & DPRINTF(XSCompositePrefetcher, "SMS: region_size: %d regionBlks: %d\n", regionSize, regionBlks); + if (Xsstream) + { + Xsstream->stridestream_pfFilter_l1 = &this->stridestream_pfFilter_l1; + 
Xsstream->stridestream_pfFilter_l2l3 = &this->stridestream_pfFilter_l2l3; + } + if (Sstride) + { + Sstride->stridestream_pfFilter_l1 = &this->stridestream_pfFilter_l1; + Sstride->stridestream_pfFilter_l2l3 = &this->stridestream_pfFilter_l2l3; + } + assert(phtSentPrefetch.size() == 0); + for(unsigned i = 0; i < 3; i++) + phtSentPrefetch.push_back(phtsentInfo()); + } void @@ -90,6 +126,7 @@ XSCompositePrefetcher::calculatePrefetch(const PrefetchInfo &pfi, std::vectorcalculatePrefetch(pfi, addresses, streamlatenum); + stats.streamTrainCount++; + } act_match_entry = actLookup(pfi, is_active_page, enter_new_region, is_first_shot); if (enableOpt){ assert(Opt); @@ -262,6 +301,7 @@ XSCompositePrefetcher::calculatePrefetch(const PrefetchInfo &pfi, std::vectorcalculatePrefetch(pfi, addresses, late, pf_source, miss_repeat, enter_new_region, is_first_shot, stride_pf_addr, learned_bop_offset); if (learned_bop_offset != 0) @@ -408,6 +448,7 @@ XSCompositePrefetcher::updatePht(XSCompositePrefetcher::ACTEntry *act_entry, Add } pht_entry->pc = act_entry->pc; act_entry->hasIncreasedPht = true; + pht_entry->decr_mode = act_entry->inBackwardMode; } else { return; } @@ -420,6 +461,7 @@ XSCompositePrefetcher::updatePht(XSCompositePrefetcher::ACTEntry *act_entry, Add pht_entry->hist[i].reset(); } pht_entry->pc = act_entry->pc; + pht_entry->decr_mode = act_entry->inBackwardMode; } pht.accessEntry(pht_entry); @@ -506,8 +548,13 @@ XSCompositePrefetcher::phtLookup(const Base::PrefetchInfo &pfi, std::vectorhist[i + regionBlks - 1].calcSaturation() > 0.5) { Addr pf_tgt_addr = blk_addr + (i + 1) * blkSize; - sendPFWithFilter(pfi, pf_tgt_addr, addresses, priority--, PrefetchSourceType::SPht, phtPFLevel); - found = true; + if(regionAddress(pf_tgt_addr) == region_addr) { + region_bit_cur |= (uint64_t(1) << regionOffset(pf_tgt_addr)); + sendPFWithFilter(pfi, pf_tgt_addr, addresses, priority--, PrefetchSourceType::SPht, phtPFLevel); + found = true; + } + } + } + for (int i = regionBlks - 2, j = 1; i 
>= 0; i--, j++) { + if (pht_entry->hist[i].calcSaturation() > 0.5) { + Addr pf_tgt_addr = blk_addr - j * blkSize; + if(regionAddress(pf_tgt_addr) == region_addr) { + region_bit_cur |= (uint64_t(1) << regionOffset(pf_tgt_addr)); + sendPFWithFilter(pfi, pf_tgt_addr, addresses, priority--, PrefetchSourceType::SPht, phtPFLevel); + found = true; + } + } + } + if(found){ + if(phtSentPrefetch[0].valid){ + stats.smsCurRegionoverride++; + } + phtSentPrefetch[0] = phtsentInfo(region_addr, region_bit_cur ,0, true,pht_entry->decr_mode,secure,phtPFLevel, &pfi.trigger_info); + phtSentPrefetch[0].trigger.pfSourceType = PrefetchSourceType::SPht; + } + found = false; + for (uint8_t i = 0; i < regionBlks - 1; i++) { + if (pht_entry->hist[i + regionBlks - 1].calcSaturation() > 0.5) { + Addr pf_tgt_addr = blk_addr + (i + 1) * blkSize; + if(regionAddress(pf_tgt_addr) != region_addr) { + region_inc_addr = regionAddress(pf_tgt_addr); + region_bit_inc |= (uint64_t(1) << regionOffset(pf_tgt_addr)); + sendPFWithFilter(pfi, pf_tgt_addr, addresses, priority--, PrefetchSourceType::SPht, phtPFLevel); + found = true; + } + } + } + if(found){ + if(phtSentPrefetch[1].valid){ + stats.smsIncrRegionoverride++; } + phtSentPrefetch[1] = phtsentInfo(region_inc_addr, region_bit_inc ,0, true,pht_entry->decr_mode,secure,phtPFLevel, &pfi.trigger_info); + phtSentPrefetch[1].trigger.pfSourceType = PrefetchSourceType::SPht; } + + found = false; for (int i = regionBlks - 2, j = 1; i >= 0; i--, j++) { if (pht_entry->hist[i].calcSaturation() > 0.5) { Addr pf_tgt_addr = blk_addr - j * blkSize; - sendPFWithFilter(pfi, pf_tgt_addr, addresses, priority--, PrefetchSourceType::SPht, phtPFLevel); - found = true; + if(regionAddress(pf_tgt_addr) != region_addr) { + region_dec_addr = regionAddress(pf_tgt_addr); + region_bit_dec |= (uint64_t(1) << regionOffset(pf_tgt_addr)); + sendPFWithFilter(pfi, pf_tgt_addr, addresses, priority--, PrefetchSourceType::SPht, phtPFLevel); + found = true; + } } } + if(found){ + 
if(phtSentPrefetch[2].valid){ + stats.smsDecrRegionoverride++; + } + phtSentPrefetch[2] = phtsentInfo(region_dec_addr, region_bit_dec ,0, true,pht_entry->decr_mode,secure,phtPFLevel, &pfi.trigger_info); + phtSentPrefetch[2].trigger.pfSourceType = PrefetchSourceType::SPht; + } + if (!phtReqSendEvent.scheduled()){ + phtSendEventWrapper(); + } + DPRINTF(XSCompositePrefetcher, "pht entry pattern:\n"); for (uint8_t i = 0; i < 2 * (regionBlks - 1); i++) { DPRINTFR(XSCompositePrefetcher, "%.2f ", pht_entry->hist[i].calcSaturation()); @@ -547,20 +650,31 @@ bool XSCompositePrefetcher::sendPFWithFilter(const PrefetchInfo &pfi, Addr addr, std::vector &addresses, int prio, PrefetchSourceType src, int ahead_level) { + // Count generated prefetch + prefetchStats.pfGenerated++; + if (ahead_level < 2 && pfPageLRUFilter.contains(regionAddress(addr))) { DPRINTF(XSCompositePrefetcher, "Skip recently L1 prefetched page: %lx\n", regionAddress(addr)); + // Count filtered prefetch + prefetchStats.pfFiltered++; return false; } else if (ahead_level == 2 && pfPageLRUFilterL2.contains(regionAddress(addr))) { DPRINTF(XSCompositePrefetcher, "Skip recently L2 prefetched page: %lx\n", regionAddress(addr)); + // Count filtered prefetch + prefetchStats.pfFiltered++; return false; } else if (ahead_level == 3 && pfPageLRUFilterL3.contains(regionAddress(addr))) { DPRINTF(XSCompositePrefetcher, "Skip recently L3 prefetched page: %lx\n", regionAddress(addr)); + // Count filtered prefetch + prefetchStats.pfFiltered++; return false; } else if (pfBlockLRUFilter.contains(addr)) { DPRINTF(XSCompositePrefetcher, "Skip recently prefetched: %lx\n", addr); + // Count filtered prefetch + prefetchStats.pfFiltered++; return false; } else { @@ -587,6 +701,7 @@ void XSCompositePrefetcher::sendStreamPF(const PrefetchInfo &pfi, Addr pf_tgt_addr, std::vector &addresses, boost::compute::detail::lru_cache &Filter, bool decr, int pf_level) { + uint64_t region_bit = 0; Addr pf_tgt_region = regionAddress(pf_tgt_addr); Addr 
pf_tgt_offset = regionOffset(pf_tgt_addr); PrefetchSourceType stream_type = PrefetchSourceType::SStream; @@ -597,10 +712,21 @@ XSCompositePrefetcher::sendStreamPF(const PrefetchInfo &pfi, Addr pf_tgt_addr, s DPRINTF(XSCompositePrefetcher, "tgt addr: %x, offset: %d ,page: %lx\n", pf_tgt_addr, pf_tgt_offset, pf_tgt_region); for (int i = 0; i < regionBlks; i++) { Addr cur = pf_tgt_region * regionSize + i * blkSize; + region_bit |= (uint64_t(1) << regionOffset(cur)); sendPFWithFilter(pfi, cur, addresses, regionBlks - i, stream_type, pf_level); DPRINTF(XSCompositePrefetcher, "pf addr: %x [%d] pf_level %d\n", cur, i, pf_level); fatal_if(i < 0, "i < 0\n"); } + //use for act to insert PFfilter + pfi.setTriggerInfo_PFsrc(stream_type); + if (pf_level > 1) { + stridestream_pfFilter_l2l3.Insert(regionAddress(pf_tgt_addr), + region_bit,0,true,decr,pfi.isSecure(),pf_level, &pfi.trigger_info); + } else { + stridestream_pfFilter_l1.Insert(regionAddress(pf_tgt_addr), + region_bit,0,true,decr,pfi.isSecure(),pf_level, &pfi.trigger_info); + } + Filter.insert(pf_tgt_region, 0); } @@ -640,7 +766,13 @@ XSCompositePrefetcher::XSCompositeStats::XSCompositeStats(statistics::Group *par ADD_STAT(allCntNum, statistics::units::Count::get(), "victim act access num"), ADD_STAT(actMNum, statistics::units::Count::get(), "victim act match num"), ADD_STAT(refillNotifyCount, statistics::units::Count::get(), "refill notify count"), - ADD_STAT(bopTrainCount, statistics::units::Count::get(), "bop train count") + ADD_STAT(bopTrainCount, statistics::units::Count::get(), "bop train count"), + ADD_STAT(smsCurRegionoverride, statistics::units::Count::get(), "sms current region override prefetches"), + ADD_STAT(smsIncrRegionoverride, statistics::units::Count::get(), "sms increased region override prefetches"), + ADD_STAT(smsDecrRegionoverride, statistics::units::Count::get(), "sms decreased region override prefetches"), + ADD_STAT(strideTrainCount, statistics::units::Count::get(), "stride train count"), + 
ADD_STAT(streamTrainCount, statistics::units::Count::get(), "stream train count"), + ADD_STAT(totalTrainCount, statistics::units::Count::get(), "total train count") { } @@ -661,6 +793,101 @@ XSCompositePrefetcher::setParentInfo(System *sys, ProbeManager *pm, CacheAccesso if (ipcp) ipcp->setParentInfo(sys, pm, _cache, blk_size); } - +bool XSCompositePrefetcher::GetPFRequestsFromBuffer(std::vector &addresses) +{ + //here we decide which to send for this cycle + //L1 Streamstride>berti>SMS>CMC>learnedBOP>smallBOP>largeBOP + //L2 Streamstride>SMS>BOP>TP + //first we get 1 L1PF + bool L1PFsent = false; + if (stridestream_pfFilter_l1.hasPFRequestsInBuffer()){ + L1PFsent = stridestream_pfFilter_l1.GetPFAddrL1(addresses); + } + if (!L1PFsent && berti->hasPFRequestsInBuffer()){ + L1PFsent = berti->GetPFRequestsFromBuffer(addresses); + } + if(!L1PFsent && sms_pfFilter.hasPFRequestsInBuffer()){ + L1PFsent = sms_pfFilter.GetPFAddrL1(addresses); + } + if(!L1PFsent && cmc->hasPFRequestsInBuffer()){ + L1PFsent = cmc->GetPFRequestsFromBuffer(addresses); + } + if (BOPPFlevel == 1 && !L1PFsent && learnedBOP->hasPFRequestsInBuffer()){ + L1PFsent = learnedBOP->GetPFRequestsFromBuffer(addresses); + } + if (BOPPFlevel == 1 && !L1PFsent && smallBOP->hasPFRequestsInBuffer()){ + L1PFsent = smallBOP->GetPFRequestsFromBuffer(addresses); + } + if (BOPPFlevel == 1 && !L1PFsent && largeBOP->hasPFRequestsInBuffer()){ + L1PFsent = largeBOP->GetPFRequestsFromBuffer(addresses); + } + if (!L1PFsent && spp->hasPFRequestsInBuffer()){ + L1PFsent = spp->GetPFRequestsFromBuffer(addresses); + } + if (!L1PFsent && ipcp->hasPFRequestsInBuffer()){ + L1PFsent = ipcp->GetPFRequestsFromBuffer(addresses); + } + if (!L1PFsent && Opt->hasPFRequestsInBuffer()){ + L1PFsent = Opt->GetPFRequestsFromBuffer(addresses); + } + bool L2PFsent = false; + if (stridestream_pfFilter_l2l3.hasPFRequestsInBuffer()){ + L2PFsent = stridestream_pfFilter_l2l3.GetPFAddrL2(addresses); + } + if (!L2PFsent && 
sms_pfFilter.hasPFRequestsInBuffer()){ + L2PFsent = sms_pfFilter.GetPFAddrL2(addresses); + } + if (BOPPFlevel == 2 && !L2PFsent && largeBOP->hasPFRequestsInBuffer()){ + L2PFsent = largeBOP->GetPFRequestsFromBuffer(addresses); + addresses.back().pfahead_host = 2; + addresses.back().pfahead = true; + } + if (BOPPFlevel == 2 && !L2PFsent && smallBOP->hasPFRequestsInBuffer()){ + L2PFsent = smallBOP->GetPFRequestsFromBuffer(addresses); + addresses.back().pfahead_host = 2; + addresses.back().pfahead = true; + } + if (BOPPFlevel == 2 && !L2PFsent && learnedBOP->hasPFRequestsInBuffer()){ + L2PFsent = learnedBOP->GetPFRequestsFromBuffer(addresses); + addresses.back().pfahead_host = 2; + addresses.back().pfahead = true; + } + bool L3PFsent = false; + L3PFsent = stridestream_pfFilter_l2l3.GetPFAddrL3(addresses); + if (!L3PFsent && sms_pfFilter.hasPFRequestsInBuffer()){ + L3PFsent = sms_pfFilter.GetPFAddrL3(addresses); + } + return L1PFsent || L2PFsent || L3PFsent; +} +bool XSCompositePrefetcher::hasPFRequestsInBuffer() { + return sms_pfFilter.hasPFRequestsInBuffer() || + stridestream_pfFilter_l1.hasPFRequestsInBuffer() || + stridestream_pfFilter_l2l3.hasPFRequestsInBuffer() || + largeBOP->hasPFRequestsInBuffer() || + smallBOP->hasPFRequestsInBuffer() || + learnedBOP->hasPFRequestsInBuffer() || + berti->hasPFRequestsInBuffer() || + cmc->hasPFRequestsInBuffer() || + spp->hasPFRequestsInBuffer() || + ipcp->hasPFRequestsInBuffer() || + Opt->hasPFRequestsInBuffer() ; +} +void +XSCompositePrefetcher::phtSendEventWrapper(){ + for(int i=0; i<3; i++){ + if (phtSentPrefetch[i].valid){ + sms_pfFilter.Insert(phtSentPrefetch[i].region_addr, phtSentPrefetch[i].region_bits, + phtSentPrefetch[i].alias_bits,phtSentPrefetch[i].paddr_valid, phtSentPrefetch[i].decr_mode, + phtSentPrefetch[i].is_secure,phtSentPrefetch[i].PFlevel, &phtSentPrefetch[i].trigger); + phtSentPrefetch[i].valid = false; + break; + } + } + if (!phtReqSendEvent.scheduled()){ + if(phtSentPrefetch[0].valid || 
phtSentPrefetch[1].valid || phtSentPrefetch[2].valid) + schedule(phtReqSendEvent, nextCycle()); + } + +} } // prefetch } // gem5 diff --git a/src/mem/cache/prefetch/sms.hh b/src/mem/cache/prefetch/sms.hh index 64b03aa121..df04c36fc7 100644 --- a/src/mem/cache/prefetch/sms.hh +++ b/src/mem/cache/prefetch/sms.hh @@ -6,6 +6,8 @@ #define GEM5_SMS_HH #include +#include +#include #include @@ -30,17 +32,26 @@ namespace gem5 { struct XSCompositePrefetcherParams; +class BaseIndexingPolicy; +namespace replacement_policy { class Base; } GEM5_DEPRECATED_NAMESPACE(Prefetcher, prefetch); namespace prefetch { +// PrefetchFilter is implemented in its own header/source to keep sms. +#include "mem/cache/prefetch/prefetch_filter.hh" + + class XSCompositePrefetcher : public Queued { protected: const unsigned int regionSize; const unsigned int regionBlks; + const bool enableTrainFilter; // Enable TrainFilter for ROB-order training + + bool useTrainingBuffer() const override { return enableTrainFilter; } Addr regionAddress(Addr a) { return a / regionSize; }; @@ -109,8 +120,9 @@ class XSCompositePrefetcher : public Queued public: std::vector hist; Addr pc; + bool decr_mode; PhtEntry(const size_t sz, const SatCounter8 &conf) - : TaggedEntry(), hist(sz, conf) + : TaggedEntry(), hist(sz, conf), decr_mode(false) { } }; @@ -135,6 +147,12 @@ class XSCompositePrefetcher : public Queued statistics::Scalar actMNum; statistics::Scalar refillNotifyCount; statistics::Scalar bopTrainCount; + statistics::Scalar smsCurRegionoverride; + statistics::Scalar smsIncrRegionoverride; + statistics::Scalar smsDecrRegionoverride; + statistics::Scalar strideTrainCount; + statistics::Scalar streamTrainCount; + statistics::Scalar totalTrainCount; } stats; public: @@ -157,6 +175,10 @@ class XSCompositePrefetcher : public Queued const unsigned pfPageFilterSize{16}; boost::compute::detail::lru_cache pfBlockLRUFilter; + PrefetchFilter sms_pfFilter; + PrefetchFilter stridestream_pfFilter_l1; + PrefetchFilter 
stridestream_pfFilter_l2l3; + boost::compute::detail::lru_cache pfPageLRUFilter; boost::compute::detail::lru_cache pfPageLRUFilterL2; boost::compute::detail::lru_cache pfPageLRUFilterL3; @@ -206,6 +228,44 @@ class XSCompositePrefetcher : public Queued } } void setParentInfo(System *sys, ProbeManager *pm, CacheAccessor* _cache, unsigned blk_size) override; + + protected: + using TriggerInfo = Base::PFtriggerInfo; + struct phtsentInfo { + bool valid; + Addr region_addr; + uint64_t region_bits; + uint8_t alias_bits; + bool paddr_valid; + bool decr_mode; + bool is_secure; + uint64_t PFlevel; + TriggerInfo trigger; + // phtsentInfo() + // : valid(false), region_addr(0), region_bits(0), alias_bits(0), paddr_valid(false), + // decr_mode(false), is_secure(false), PFlevel(0), trigger() {}; + phtsentInfo(Addr region_addr = 0, uint64_t region_bits = 0, uint8_t alias_bits = 0, + bool paddr_valid = false, bool decr_mode = false, + bool is_secure = false, uint64_t PFlevel = 0, + const TriggerInfo *trigger = nullptr) + : valid(true), region_addr(region_addr), region_bits(region_bits), alias_bits(alias_bits), + paddr_valid(paddr_valid), decr_mode(decr_mode), is_secure(is_secure), + PFlevel(PFlevel), trigger(trigger == nullptr ? 
TriggerInfo() : *trigger) {}; + ~phtsentInfo() = default; + }; + std::vector phtSentPrefetch;//0 cur ,1 inc ,2 dec + /** Event to handle the pht sending */ + void phtSendEventWrapper(); + EventFunctionWrapper phtReqSendEvent; + protected: + void InsertPFRequestToBuffer(const AddrPriority &addr_prio) override{ + panic("SMS:InsertPFRequestToBuffer not implemented"); + }; + public: + bool GetPFRequestsFromBuffer(std::vector &addresses) override; + bool hasPFRequestsInBuffer() override; + protected: + const int BOPPFlevel; }; } diff --git a/src/mem/cache/prefetch/worker.cc b/src/mem/cache/prefetch/worker.cc index aa3bd92cfa..6602ebe7a0 100644 --- a/src/mem/cache/prefetch/worker.cc +++ b/src/mem/cache/prefetch/worker.cc @@ -79,6 +79,7 @@ WorkerPrefetcher::transfer() } dpp_it = localBuffer.erase(dpp_it); count++; + latestTransferTick = curTick(); } schedule(transferEvent, nextCycle()); } diff --git a/src/mem/cache/prefetch/worker.hh b/src/mem/cache/prefetch/worker.hh index f94cca02f7..c9547cab50 100644 --- a/src/mem/cache/prefetch/worker.hh +++ b/src/mem/cache/prefetch/worker.hh @@ -85,6 +85,7 @@ class WorkerPrefetcher : public Queued std::list localBuffer; unsigned depth{4}; + Tick latestTransferTick{0}; }; } // namespace prefetch diff --git a/src/mem/cache/prefetch/xs_stream.cc b/src/mem/cache/prefetch/xs_stream.cc index 7dba1b8b7c..8d414907ba 100644 --- a/src/mem/cache/prefetch/xs_stream.cc +++ b/src/mem/cache/prefetch/xs_stream.cc @@ -10,6 +10,8 @@ namespace prefetch XsStreamPrefetcher::XsStreamPrefetcher(const XsStreamPrefetcherParams &p) : Queued(p), + regionSize(p.region_size), + regionBlks(p.region_size / p.block_size), depth(p.xs_stream_depth), badPreNum(0), enableAutoDepth(p.enable_auto_depth), @@ -52,14 +54,14 @@ XsStreamPrefetcher::calculatePrefetch(const PrefetchInfo &pfi, std::vector &addresses, - int prio, PrefetchSourceType src, int pf_degree, int ahead_level) + int prio, PrefetchSourceType src, int pf_degree, int ahead_level, STREAMEntry *entry) { + 
uint64_t region_bit = 0; for (int i = 0; i < pf_degree; i++) { Addr pf_addr = addr + i * blkSize; + region_bit |= (uint64_t(1) << regionOffset(pf_addr)); + + // Count generated prefetch + prefetchStats.pfGenerated++; + if (filter->contains(pf_addr)) { DPRINTF(XsStreamPrefetcher, "Skip recently prefetched: %lx\n", pf_addr); + // Count filtered prefetch + prefetchStats.pfFiltered++; } else { DPRINTF(XsStreamPrefetcher, "Send pf: %lx\n", pf_addr); filter->insert(pf_addr, 0); @@ -130,6 +140,12 @@ XsStreamPrefetcher::sendPFWithFilter(const PrefetchInfo &pfi, Addr addr, std::ve } } } + pfi.setTriggerInfo_PFsrc(src); + if (ahead_level > 1) { + stridestream_pfFilter_l2l3->Insert(regionAddress(addr), region_bit,0,true,entry->decrMode,pfi.isSecure(),ahead_level, &pfi.trigger_info); + } else { + stridestream_pfFilter_l1->Insert(regionAddress(addr), region_bit,0,true,entry->decrMode,pfi.isSecure(),ahead_level, &pfi.trigger_info); + } } diff --git a/src/mem/cache/prefetch/xs_stream.hh b/src/mem/cache/prefetch/xs_stream.hh index f38cdb00da..0c200a94ca 100644 --- a/src/mem/cache/prefetch/xs_stream.hh +++ b/src/mem/cache/prefetch/xs_stream.hh @@ -10,10 +10,10 @@ #include "base/types.hh" #include "debug/XsStreamPrefetcher.hh" #include "mem/cache/prefetch/associative_set.hh" -#include "mem/cache/prefetch/queued.hh" +// #include "mem/cache/prefetch/queued.hh" #include "mem/packet.hh" #include "params/XsStreamPrefetcher.hh" - +#include "mem/cache/prefetch/prefetch_filter.hh" namespace gem5 { struct XsStreamPrefetcherParams; @@ -22,6 +22,14 @@ namespace prefetch { class XsStreamPrefetcher : public Queued { + protected: + const unsigned int regionSize; + const unsigned int regionBlks; + + + Addr regionAddress(Addr a) { return a / regionSize; }; + + Addr regionOffset(Addr a) { return (a / blkSize) % regionBlks; } protected: int depth; int badPreNum; @@ -80,7 +88,7 @@ class XsStreamPrefetcher : public Queued AssociativeSet stream_array; STREAMEntry *streamLookup(const PrefetchInfo &pfi, 
bool &in_active_page, bool &decr); void sendPFWithFilter(const PrefetchInfo &pfi, Addr addr, std::vector &addresses, int prio, - PrefetchSourceType src, int pf_degree, int ahead_level = -1); + PrefetchSourceType src, int pf_degree, int ahead_level = -1, STREAMEntry *entry = nullptr); public: boost::compute::detail::lru_cache *filter; @@ -93,6 +101,8 @@ class XsStreamPrefetcher : public Queued panic("not implemented"); }; void calculatePrefetch(const PrefetchInfo &pfi, std::vector &addresses, int late_num); + PrefetchFilter* stridestream_pfFilter_l1; + PrefetchFilter* stridestream_pfFilter_l2l3; }; } } diff --git a/src/mem/cache/prefetch/xs_stride.cc b/src/mem/cache/prefetch/xs_stride.cc index 5fbc91e76b..0474459938 100644 --- a/src/mem/cache/prefetch/xs_stride.cc +++ b/src/mem/cache/prefetch/xs_stride.cc @@ -3,6 +3,7 @@ #include "mem/cache/prefetch/xs_stride.hh" +#include "base/stats/group.hh" #include "debug/XSStridePrefetcher.hh" #include "mem/cache/prefetch/associative_set_impl.hh" @@ -12,16 +13,20 @@ namespace prefetch { XSStridePrefetcher::XSStridePrefetcher(const XSStridePrefetcherParams &p) - : Queued(p),useXsDepth(p.use_xs_depth),fuzzyStrideMatching(p.fuzzy_stride_matching), + : Queued(p),useXsDepth(p.use_xs_depth),useRedundantTable(p.use_redundant_table), + fuzzyStrideMatching(p.fuzzy_stride_matching), shortStrideThres(p.short_stride_thres), strideDynDepth(p.stride_dyn_depth), enableNonStrideFilter(p.enable_non_stride_filter), - strideUnique(p.stride_entries, p.stride_entries, p.stride_indexing_policy, - p.stride_replacement_policy, StrideEntry()), - strideRedundant(p.stride_entries, p.stride_entries, p.stride_indexing_policy, - p.stride_replacement_policy, StrideEntry()), + regionSize(p.region_size), + regionBlks(p.region_size / p.block_size), + strideUnique(p.stride_entries, p.stride_entries, p.stride_unique_indexing_policy, + p.stride_unique_replacement_policy, StrideEntry()), + strideRedundant(p.stride_entries, p.stride_entries, 
p.stride_redundant_indexing_policy, + p.stride_redundant_replacement_policy, StrideEntry()), nonStridePCs(p.non_stride_assoc, p.non_stride_entries, p.non_stride_indexing_policy, - p.non_stride_replacement_policy, NonStrideEntry()) + p.non_stride_replacement_policy, NonStrideEntry()), + stats(this) { } @@ -31,22 +36,27 @@ XSStridePrefetcher::calculatePrefetch(const PrefetchInfo &pfi, std::vector &stride, const PrefetchInfo &pfi, std::vector &addresses, bool late, Addr &stride_pf, PrefetchSourceType last_pf_source, bool enter_new_region, bool miss_repeat, - int64_t &learned_bop_offset) + int64_t &learned_bop_offset, bool is_first_shot) { + if (is_first_shot) { + stats.strideUniquequeryCount++; + } else { + stats.strideRedundantqueryCount++; + } Addr lookupAddr = pfi.getAddr(); Addr stride_hash_pc = strideHashPc(pfi.getPC()); StrideEntry *entry = stride.findEntry(stride_hash_pc, pfi.isSecure()); @@ -56,6 +66,22 @@ XSStridePrefetcher::strideLookup(AssociativeSet &stride, const Pref miss_repeat); bool should_cover = false; if (entry) { + if (archDBer){ + archDBer->strideTraceWrite(curTick(), lookupAddr, pfi.getPC(), stride_hash_pc, + true, is_first_shot, pfi.isCacheMiss(), true); + } + }else{ + if (archDBer){ + archDBer->strideTraceWrite(curTick(), lookupAddr, pfi.getPC(), stride_hash_pc, + false, is_first_shot, pfi.isCacheMiss(), true); + } + } + if (entry) { + if (is_first_shot) { + stats.strideUniquehitCount++; + } else { + stats.strideRedundanthitCount++; + } stride.accessEntry(entry); int64_t new_stride = lookupAddr - entry->lastAddr; if (new_stride == 0 || (labs(new_stride) < 64 && (miss_repeat || entry->longStride.calcSaturation() >= 0.5))) { @@ -155,11 +181,27 @@ XSStridePrefetcher::strideLookup(AssociativeSet &stride, const Pref PrefetchSourceType::SStride, 1); sendPFWithFilter(pfi, blockAddress(lookupAddr + (entry->stride << 5)), addresses, 0, PrefetchSourceType::SStride, 2); + if (is_first_shot) { + stats.strideUniquepfCount += 2; + } else { + 
stats.strideRedundantpfCount += 2; + } + if (archDBer){ + archDBer->strideTraceWrite(curTick(), blockAddress(lookupAddr + (entry->stride << 2)), pfi.getPC(), stride_hash_pc, + true, is_first_shot, pfi.isCacheMiss(), false); + archDBer->strideTraceWrite(curTick(), blockAddress(lookupAddr + (entry->stride << 5)), pfi.getPC(), stride_hash_pc, + true, is_first_shot, pfi.isCacheMiss(), false); + } } else { for (unsigned i = start_depth; i <= entry->depth; i++) { pf_addr = lookupAddr + entry->stride * i; DPRINTF(XSStridePrefetcher, "Stride conf >= 2, send pf: %x with depth %i\n", pf_addr, i); sendPFWithFilter(pfi, blockAddress(pf_addr), addresses, 0, PrefetchSourceType::SStride, 1); + if (is_first_shot) { + stats.strideUniquepfCount++; + } else { + stats.strideRedundantpfCount++; + } } stride_pf = pf_addr; // the longest lookahead } @@ -167,6 +209,11 @@ XSStridePrefetcher::strideLookup(AssociativeSet &stride, const Pref should_cover = true; } } else { + if (is_first_shot) { + stats.strideUniquemissCount++; + } else { + stats.strideRedundantmissCount++; + } DPRINTF(XSStridePrefetcher, "Stride miss, insert it\n"); entry = stride.findVictim(0); DPRINTF(XSStridePrefetcher, "Stride found victim pc = %x, stride = %i\n", entry->pc, entry->stride); @@ -175,6 +222,13 @@ XSStridePrefetcher::strideLookup(AssociativeSet &stride, const Pref maxHistStrides - 1, entry->pc); markNonStridePC(entry->pc); } + if (entry->conf >= 2){ + if (is_first_shot) { + stats.strideUniquereplaceusefulCount++; + } else { + stats.strideRedundantreplaceusefulCount++; + } + } if (entry->conf >= 2 && entry->stride > 1024) { // > 1k DPRINTF(XSStridePrefetcher, "Stride Evicting a useful stride, send it to BOP with offset %i\n", entry->stride / 64); @@ -239,9 +293,15 @@ void XSStridePrefetcher::sendPFWithFilter(const PrefetchInfo &pfi, Addr addr, std::vector &addresses, int prio, PrefetchSourceType src, int ahead_level) { + // Count generated prefetch + prefetchStats.pfGenerated++; + 
pfi.setTriggerInfo_PFsrc(src); if (ahead_level > 1){ + stridestream_pfFilter_l2l3->Insert(regionAddress(addr), uint64_t(1) << regionOffset(addr),0,true,false,pfi.isSecure(),ahead_level, &pfi.trigger_info); if (filterL2->contains(addr)) { DPRINTF(XSStridePrefetcher, "Skip recently prefetched: %lx\n", addr); + // Count filtered prefetch + prefetchStats.pfFiltered++; } else { DPRINTF(XSStridePrefetcher, "Send pf: %lx\n", addr); filterL2->insert(addr, 0); @@ -251,8 +311,11 @@ XSStridePrefetcher::sendPFWithFilter(const PrefetchInfo &pfi, Addr addr, std::ve addresses.back().pfahead = true; } } else { + stridestream_pfFilter_l1->Insert(regionAddress(addr), uint64_t(1) << regionOffset(addr),0,true,false,pfi.isSecure(),ahead_level, &pfi.trigger_info); if (filter->contains(addr)) { DPRINTF(XSStridePrefetcher, "Skip recently prefetched: %lx\n", addr); + // Count filtered prefetch + prefetchStats.pfFiltered++; } else { DPRINTF(XSStridePrefetcher, "Send pf: %lx\n", addr); filter->insert(addr, 0); @@ -273,6 +336,21 @@ XSStridePrefetcher::strideHashPc(Addr pc) return (pc_high << 10) | pc_low; } +XSStridePrefetcher::XSstrideStats::XSstrideStats(statistics::Group *parent) + : statistics::Group(parent), + ADD_STAT(strideUniquequeryCount, statistics::units::Count::get(), "stride table query num"), + ADD_STAT(strideUniquehitCount, statistics::units::Count::get(), "stride table hit num"), + ADD_STAT(strideUniquemissCount, statistics::units::Count::get(), "stride table miss num"), + ADD_STAT(strideUniquepfCount, statistics::units::Count::get(), "stride prefetch num"), + ADD_STAT(strideUniquereplaceusefulCount, statistics::units::Count::get(), "stride table replace num"), + ADD_STAT(strideRedundantqueryCount, statistics::units::Count::get(), "stride table query num"), + ADD_STAT(strideRedundanthitCount, statistics::units::Count::get(), "stride table hit num"), + ADD_STAT(strideRedundantmissCount, statistics::units::Count::get(), "stride table miss num"), + 
ADD_STAT(strideRedundantpfCount, statistics::units::Count::get(), "stride prefetch num"), + ADD_STAT(strideRedundantreplaceusefulCount, statistics::units::Count::get(), "stride table replace num") +{ +} + } } diff --git a/src/mem/cache/prefetch/xs_stride.hh b/src/mem/cache/prefetch/xs_stride.hh index 5af9fbe96d..9b1af4c6fb 100644 --- a/src/mem/cache/prefetch/xs_stride.hh +++ b/src/mem/cache/prefetch/xs_stride.hh @@ -14,10 +14,10 @@ #include "base/types.hh" #include "debug/XSStridePrefetcher.hh" #include "mem/cache/prefetch/associative_set.hh" -#include "mem/cache/prefetch/queued.hh" +// #include "mem/cache/prefetch/queued.hh" #include "mem/packet.hh" #include "params/XSStridePrefetcher.hh" - +#include "mem/cache/prefetch/prefetch_filter.hh" namespace gem5 { @@ -32,11 +32,19 @@ class XSStridePrefetcher : public Queued { protected: const bool useXsDepth; + const bool useRedundantTable; const bool fuzzyStrideMatching; const unsigned shortStrideThres; const bool strideDynDepth{false}; const bool enableNonStrideFilter; + protected: + const unsigned int regionSize; + const unsigned int regionBlks; + + + Addr regionAddress(Addr a) { return a / regionSize; }; + Addr regionOffset(Addr a) { return (a / blkSize) % regionBlks; } class StrideEntry : public TaggedEntry { @@ -74,7 +82,7 @@ class XSStridePrefetcher : public Queued bool strideLookup(AssociativeSet &stride, const PrefetchInfo &pfi, std::vector &address, bool late, Addr &pf_addr, PrefetchSourceType src, bool enter_new_region, bool miss_repeat, - int64_t &learned_bop_offset); + int64_t &learned_bop_offset, bool is_first_shot); AssociativeSet strideUnique; @@ -115,6 +123,24 @@ class XSStridePrefetcher : public Queued void calculatePrefetch(const PrefetchInfo &pfi, std::vector &addresses, bool late, PrefetchSourceType pf_source, bool miss_repeat, bool enter_new_region, bool is_first_shot, Addr &pf_addr, int64_t &learned_bop_offset); + PrefetchFilter* stridestream_pfFilter_l1; + PrefetchFilter* 
stridestream_pfFilter_l2l3; + + struct XSstrideStats : public statistics::Group + { + XSstrideStats(statistics::Group *parent); + statistics::Scalar strideUniquequeryCount; + statistics::Scalar strideUniquehitCount; + statistics::Scalar strideUniquemissCount; + statistics::Scalar strideUniquepfCount; + statistics::Scalar strideUniquereplaceusefulCount; + statistics::Scalar strideRedundantqueryCount; + statistics::Scalar strideRedundanthitCount; + statistics::Scalar strideRedundantmissCount; + statistics::Scalar strideRedundantpfCount; + statistics::Scalar strideRedundantreplaceusefulCount; + + } stats; }; } diff --git a/src/sim/ArchDBer.py b/src/sim/ArchDBer.py index 56eda5ec86..d0584b07cf 100644 --- a/src/sim/ArchDBer.py +++ b/src/sim/ArchDBer.py @@ -59,6 +59,8 @@ class ArchDBer(SimObject): dump_l1_miss_trace = Param.Bool(False, "Dump l1 miss trace") dump_bop_train_trace = Param.Bool(False, "Dump bop train trace") dump_sms_train_trace = Param.Bool(False, "Dump sms train trace") + dump_stride_train_trace = Param.Bool(False, "Dump stride train trace") + dump_despacito_train_trace = Param.Bool(False, "Dump despacito train trace") dump_l1d_way_pre_trace = Param.Bool(False, "Dump l1d way predction trace") dump_vaddr_trace = Param.Bool(False, "Dump vaddr trace") dump_lifetime = Param.Bool(False, "Dump inst lifetime") diff --git a/src/sim/arch_db.cc b/src/sim/arch_db.cc index 883033d036..e78cd8e86d 100644 --- a/src/sim/arch_db.cc +++ b/src/sim/arch_db.cc @@ -27,6 +27,8 @@ ArchDBer::ArchDBer(const Params &p) dumpL1MissTrace(p.dump_l1_miss_trace), dumpBopTrainTrace(p.dump_bop_train_trace), dumpSMSTrainTrace(p.dump_sms_train_trace), + dumpStrideTrainTrace(p.dump_stride_train_trace), + dumpDespacitoTrainTrace(p.dump_despacito_train_trace), dumpL1WayPreTrace(p.dump_l1d_way_pre_trace), dumpVaddrTrace(p.dump_vaddr_trace), dumpLifetime(p.dump_lifetime), @@ -169,7 +171,36 @@ ArchDBer::smsTrainTraceWrite(Tick tick, Addr old_addr, Addr cur_addr, Addr trigg fatal("SQL error: %s\n", 
zErrMsg); }; } +void +ArchDBer::strideTraceWrite(Tick tick, Addr addr, Addr PC, Addr hashPC, bool hit, bool isFirstShot, bool miss, bool is_train) +{ + bool dump_me = dumpGlobal && dumpStrideTrainTrace; + if (!dump_me) return; + sprintf(memTraceSQLBuf, + "INSERT INTO StrideTrainTrace(Tick,Addr,PC,HashPC,QueryHit,IsFirstShot,Miss,IsTrain,SITE) " + "VALUES(%ld,%ld,%ld,%ld,%d,%d,%d,%d,'%s');", + tick, addr, PC, hashPC, hit, isFirstShot, miss, is_train, "StrideTrain"); + rc = sqlite3_exec(mem_db, memTraceSQLBuf, callback, 0, &zErrMsg); + if (rc != SQLITE_OK) { + fatal("SQL error: %s\n", zErrMsg); + }; +} +void +ArchDBer::despacitoTraceWrite(Tick tick, Addr vaddr, Addr paddr, Addr PC, bool hasPC, bool miss, bool is_train) +{ + bool dump_me = dumpGlobal && dumpDespacitoTrainTrace; + if (!dump_me) return; + + sprintf(memTraceSQLBuf, + "INSERT INTO DespacitoTrainTrace(Tick,vAddr,pAddr,PC,hasPC,Miss,IsTrain,SITE) " + "VALUES(%ld,%ld,%ld,%ld,%d,%d,%d,'%s');", + tick, vaddr, paddr, PC, hasPC, miss, is_train, is_train?"DespacitoTrain":"DespacitoPrefetch"); + rc = sqlite3_exec(mem_db, memTraceSQLBuf, callback, 0, &zErrMsg); + if (rc != SQLITE_OK) { + fatal("SQL error: %s\n", zErrMsg); + }; +} void ArchDBer::L1MissTrace_write( uint64_t pc, uint64_t source, diff --git a/src/sim/arch_db.hh b/src/sim/arch_db.hh index 88d7035a1d..2ce14618cc 100644 --- a/src/sim/arch_db.hh +++ b/src/sim/arch_db.hh @@ -60,6 +60,8 @@ class ArchDBer : public SimObject bool dumpL1MissTrace; bool dumpBopTrainTrace; bool dumpSMSTrainTrace; + bool dumpStrideTrainTrace; + bool dumpDespacitoTrainTrace; bool dumpL1WayPreTrace; bool dumpVaddrTrace; bool dumpLifetime; @@ -101,6 +103,8 @@ class ArchDBer : public SimObject void bopTrainTraceWrite(Tick tick, Addr old_addr, Addr cur_addr, Addr offset, int score, bool miss); void smsTrainTraceWrite(Tick tick, Addr old_addr, Addr cur_addr, Addr trigger_offset, int conf, bool miss); + void strideTraceWrite(Tick tick, Addr addr, Addr PC, Addr hashPC, bool hit, bool 
isFirstShot, bool miss, bool is_train); + void despacitoTraceWrite(Tick tick, Addr vaddr, Addr paddr, Addr PC, bool hasPC, bool miss, bool is_train); void dcacheWayPreTrace(Tick tick, uint64_t pc, uint64_t vaddr, int way, int is_write); void vaddrTrace(Tick tick, uint64_t pc, uint64_t vaddr, int hit); char memTraceSQLBuf[1024]; diff --git a/util/parse_stride_trace.py b/util/parse_stride_trace.py new file mode 100644 index 0000000000..f241a37423 --- /dev/null +++ b/util/parse_stride_trace.py @@ -0,0 +1,212 @@ +#!/usr/bin/env python3 +"""Split StrideTrainTrace into training vs prediction CSVs, HTML tables, and a quick plot. + +Usage: + python util/parse_stride_trace.py \\ + --db $HOME/Trace.db \\ + --out-train train_stride.csv \\ + --out-predict predict_stride.csv \\ + --html-train train_stride.html \\ + --html-predict predict_stride.html \\ + --plot stride_trace.png --plot-limit 8000 + +The script streams results to avoid loading the whole table into memory. +""" + +import argparse +import csv +import html +import os +import sqlite3 +from typing import Iterable, List, Sequence + +try: + from matplotlib import pyplot as plt + from matplotlib import ticker + _HAS_MPL = True +except ImportError: + _HAS_MPL = False + plt = None + ticker = None + +DEFAULT_DB = ( + "/nfs/home/changqingyu/gem5_run_res/Correlation/" + "xsidealfetch_newPFalign_260108_kmhv3_removeL2Filter_counter/" + "GemsFDTD_30385_0.268180/Trace.db" +) + + +def get_columns(conn: sqlite3.Connection) -> List[str]: + """Return column names from StrideTrainTrace in declared order.""" + cur = conn.execute("PRAGMA table_info('StrideTrainTrace')") + cols = [row[1] for row in cur.fetchall()] + if not cols: + raise RuntimeError("StrideTrainTrace table not found in DB") + return cols + + +def _is_hex_column(col: str) -> bool: + return col.lower() in {"addr", "pc", "hashpc"} + + +def _format_cell(col: str, val) -> str: + if val is None: + return "" + if _is_hex_column(col): + try: + return hex(int(val)) + except 
(ValueError, TypeError): + return str(val) + return str(val) + + +def write_subset(conn: sqlite3.Connection, where_clause: str, out_path: str, columns: List[str]) -> int: + """Write subset matching where_clause to CSV, return written row count.""" + os.makedirs(os.path.dirname(os.path.abspath(out_path)) or '.', exist_ok=True) + query = f"SELECT {', '.join(columns)} FROM StrideTrainTrace WHERE {where_clause} ORDER BY Tick" + count = 0 + with open(out_path, "w", newline="") as csvfile: + writer = csv.writer(csvfile) + writer.writerow(columns) + for row in conn.execute(query): + writer.writerow(row) + count += 1 + return count + + +def write_subset_html( + conn: sqlite3.Connection, + where_clause: str, + out_path: str, + columns: List[str], + title: str, +) -> int: + """Write subset as an HTML table, return written row count.""" + os.makedirs(os.path.dirname(os.path.abspath(out_path)) or '.', exist_ok=True) + query = f"SELECT {', '.join(columns)} FROM StrideTrainTrace WHERE {where_clause} ORDER BY Tick" + count = 0 + with open(out_path, "w") as f: + f.write("\n\n") + f.write("StrideTrainTrace - " + html.escape(title) + "\n") + f.write( + "\n" + ) + f.write("

" + html.escape(title) + "

\n") + f.write("
" + "".join( + f"" + for idx, col in enumerate(columns) + ) + "
\n") + f.write("" + "".join(f"" for col in columns) + "\n") + for row in conn.execute(query): + f.write("") + for col, cell in zip(columns, row): + f.write(f"") + f.write("\n") + count += 1 + f.write("
{html.escape(col)}
{html.escape(_format_cell(col, cell))}
\n") + f.write(f"

Total rows: {count}

\n") + f.write( + "\n" + ) + f.write("") + return count + + +def _fetch_for_plot(conn: sqlite3.Connection, limit: int) -> Sequence[Sequence[int]]: + query = "SELECT Tick, Addr, IsTrain FROM StrideTrainTrace ORDER BY Tick" + if limit > 0: + query += f" LIMIT {int(limit)}" + data = conn.execute(query).fetchall() + ticks_train, addrs_train = [], [] + ticks_pred, addrs_pred = [], [] + for tick, addr, is_train in data: + if is_train: + ticks_train.append(int(tick)) + addrs_train.append(int(addr)) + else: + ticks_pred.append(int(tick)) + addrs_pred.append(int(addr)) + return ticks_train, addrs_train, ticks_pred, addrs_pred + + +def plot_trace(conn: sqlite3.Connection, out_path: str, limit: int) -> None: + if not _HAS_MPL: + raise RuntimeError("matplotlib is required for plotting; install it or omit --plot") + ticks_train, addrs_train, ticks_pred, addrs_pred = _fetch_for_plot(conn, limit) + if not ticks_train and not ticks_pred: + print("No rows to plot; skipping plot generation") + return + + fig, ax = plt.subplots(figsize=(10, 5)) + if ticks_train: + ax.scatter(ticks_train, addrs_train, s=6, c="#1f77b4", label="IsTrain=1") + if ticks_pred: + ax.scatter(ticks_pred, addrs_pred, s=6, c="#d62728", label="IsTrain=0") + ax.set_xlabel("Tick") + ax.set_ylabel("Addr (hex)") + ax.yaxis.set_major_formatter(ticker.FuncFormatter(lambda x, _: hex(int(x)))) + ax.legend(loc="best") + ax.set_title("StrideTrainTrace Addr vs Tick") + fig.tight_layout() + fig.savefig(out_path, dpi=150) + plt.close(fig) + + +def main() -> None: + parser = argparse.ArgumentParser(description="Split StrideTrainTrace into train/predict CSV files, HTML tables, and plot") + parser.add_argument("--db", default=DEFAULT_DB, help="Path to Trace.db containing StrideTrainTrace") + parser.add_argument("--out-train", default="stride_train.csv", help="Output CSV for IsTrain=1 rows") + parser.add_argument( + "--out-predict", + default="stride_predict.csv", + help="Output CSV for IsTrain=0 rows (predicted addresses)", + ) 
+ parser.add_argument("--html-train", default="stride_train.html", help="Output HTML table for IsTrain=1 rows") + parser.add_argument( + "--html-predict", + default="stride_predict.html", + help="Output HTML table for IsTrain=0 rows", + ) + parser.add_argument("--plot", default=None, help="Optional PNG output for Addr vs Tick scatter") + parser.add_argument("--plot-limit", type=int, default=8000, help="Max rows to load for plotting (0 = all)") + args = parser.parse_args() + + if not os.path.exists(args.db): + raise FileNotFoundError(f"DB not found: {args.db}") + + with sqlite3.connect(args.db) as conn: + cols = get_columns(conn) + train_rows = write_subset(conn, "IsTrain = 1", args.out_train, cols) + predict_rows = write_subset(conn, "IsTrain = 0", args.out_predict, cols) + train_rows_html = write_subset_html(conn, "IsTrain = 1", args.html_train, cols, "Stride Training Trace") + predict_rows_html = write_subset_html(conn, "IsTrain = 0", args.html_predict, cols, "Stride Prediction Trace") + if args.plot: + plot_trace(conn, args.plot, args.plot_limit) + + print(f"Wrote {train_rows} training rows to {args.out_train}") + print(f"Wrote {predict_rows} prediction rows to {args.out_predict}") + print(f"Wrote {train_rows_html} training rows to {args.html_train}") + print(f"Wrote {predict_rows_html} prediction rows to {args.html_predict}") + if args.plot: + print(f"Wrote plot to {args.plot} (limit={args.plot_limit})") + + +if __name__ == "__main__": + main()