@@ -312,107 +312,6 @@ namespace ALPAKA_ACCELERATOR_NAMESPACE::lst {
312312 return chiSquared;
313313 }
314314
315- // 90pc threshold
316- ALPAKA_FN_ACC ALPAKA_FN_INLINE bool passPT3RPhiChiSquaredCuts (ModulesConst modules,
317- uint16_t lowerModuleIndex1,
318- uint16_t lowerModuleIndex2,
319- uint16_t lowerModuleIndex3,
320- float chiSquared) {
321- const int layer1 =
322- modules.layers ()[lowerModuleIndex1] + 6 * (modules.subdets ()[lowerModuleIndex1] == Endcap) +
323- 5 * (modules.subdets ()[lowerModuleIndex1] == Endcap and modules.moduleType ()[lowerModuleIndex1] == TwoS);
324- const int layer2 =
325- modules.layers ()[lowerModuleIndex2] + 6 * (modules.subdets ()[lowerModuleIndex2] == Endcap) +
326- 5 * (modules.subdets ()[lowerModuleIndex2] == Endcap and modules.moduleType ()[lowerModuleIndex2] == TwoS);
327- const int layer3 =
328- modules.layers ()[lowerModuleIndex3] + 6 * (modules.subdets ()[lowerModuleIndex3] == Endcap) +
329- 5 * (modules.subdets ()[lowerModuleIndex3] == Endcap and modules.moduleType ()[lowerModuleIndex3] == TwoS);
330-
331- if (layer1 == 8 and layer2 == 9 and layer3 == 10 ) {
332- return chiSquared < 7 .003f ;
333- } else if (layer1 == 8 and layer2 == 9 and layer3 == 15 ) {
334- return chiSquared < 0 .5f ;
335- } else if (layer1 == 7 and layer2 == 8 and layer3 == 9 ) {
336- return chiSquared < 8 .046f ;
337- } else if (layer1 == 7 and layer2 == 8 and layer3 == 14 ) {
338- return chiSquared < 0 .575f ;
339- } else if (layer1 == 1 and layer2 == 2 and layer3 == 7 ) {
340- return chiSquared < 5 .304f ;
341- } else if (layer1 == 1 and layer2 == 2 and layer3 == 3 ) {
342- return chiSquared < 10 .6211f ;
343- } else if (layer1 == 1 and layer2 == 7 and layer3 == 8 ) {
344- return chiSquared < 4 .617f ;
345- } else if (layer1 == 2 and layer2 == 7 and layer3 == 8 ) {
346- return chiSquared < 8 .046f ;
347- } else if (layer1 == 2 and layer2 == 7 and layer3 == 13 ) {
348- return chiSquared < 0 .435f ;
349- } else if (layer1 == 2 and layer2 == 3 and layer3 == 7 ) {
350- return chiSquared < 9 .244f ;
351- } else if (layer1 == 2 and layer2 == 3 and layer3 == 12 ) {
352- return chiSquared < 0 .287f ;
353- } else if (layer1 == 2 and layer2 == 3 and layer3 == 4 ) {
354- return chiSquared < 18 .509f ;
355- }
356-
357- return true ;
358- }
359-
360- ALPAKA_FN_ACC ALPAKA_FN_INLINE bool passPT3RPhiChiSquaredInwardsCuts (ModulesConst modules,
361- uint16_t lowerModuleIndex1,
362- uint16_t lowerModuleIndex2,
363- uint16_t lowerModuleIndex3,
364- float chiSquared) {
365- const int layer1 =
366- modules.layers ()[lowerModuleIndex1] + 6 * (modules.subdets ()[lowerModuleIndex1] == Endcap) +
367- 5 * (modules.subdets ()[lowerModuleIndex1] == Endcap and modules.moduleType ()[lowerModuleIndex1] == TwoS);
368- const int layer2 =
369- modules.layers ()[lowerModuleIndex2] + 6 * (modules.subdets ()[lowerModuleIndex2] == Endcap) +
370- 5 * (modules.subdets ()[lowerModuleIndex2] == Endcap and modules.moduleType ()[lowerModuleIndex2] == TwoS);
371- const int layer3 =
372- modules.layers ()[lowerModuleIndex3] + 6 * (modules.subdets ()[lowerModuleIndex3] == Endcap) +
373- 5 * (modules.subdets ()[lowerModuleIndex3] == Endcap and modules.moduleType ()[lowerModuleIndex3] == TwoS);
374-
375- if (layer1 == 7 and layer2 == 8 and layer3 == 9 ) // endcap layer 1,2,3, ps
376- {
377- return chiSquared < 22016 .8055f ;
378- } else if (layer1 == 7 and layer2 == 8 and layer3 == 14 ) // endcap layer 1,2,3 layer3->2s
379- {
380- return chiSquared < 935179 .56807f ;
381- } else if (layer1 == 8 and layer2 == 9 and layer3 == 10 ) // endcap layer 2,3,4
382- {
383- return chiSquared < 29064 .12959f ;
384- } else if (layer1 == 8 and layer2 == 9 and layer3 == 15 ) // endcap layer 2,3,4, layer3->2s
385- {
386- return chiSquared < 935179 .5681f ;
387- } else if (layer1 == 1 and layer2 == 2 and layer3 == 3 ) // barrel 1,2,3
388- {
389- return chiSquared < 1370 .0113195101474f ;
390- } else if (layer1 == 1 and layer2 == 2 and layer3 == 7 ) // barrel 1,2 endcap 1
391- {
392- return chiSquared < 5492 .110048314815f ;
393- } else if (layer1 == 2 and layer2 == 3 and layer3 == 4 ) // barrel 2,3,4
394- {
395- return chiSquared < 4160 .410806470067f ;
396- } else if (layer1 == 1 and layer2 == 7 and layer3 == 8 ) // barrel 1, endcap 1,2
397- {
398- return chiSquared < 29064 .129591225726f ;
399- } else if (layer1 == 2 and layer2 == 3 and layer3 == 7 ) // barrel 2,3 endcap 1
400- {
401- return chiSquared < 12634 .215376250893f ;
402- } else if (layer1 == 2 and layer2 == 3 and layer3 == 12 ) // barrel 2,3, endcap 1->2s
403- {
404- return chiSquared < 353821 .69361145404f ;
405- } else if (layer1 == 2 and layer2 == 7 and layer3 == 8 ) // barrel2, endcap 1,2
406- {
407- return chiSquared < 33393 .26076341235f ;
408- } else if (layer1 == 2 and layer2 == 7 and layer3 == 13 ) // barrel 2, endcap 1, endcap2->2s
409- {
410- return chiSquared < 935179 .5680742573f ;
411- }
412-
413- return true ;
414- }
415-
416315 ALPAKA_FN_ACC ALPAKA_FN_INLINE bool checkIntervalOverlappT3 (float firstMin,
417316 float firstMax,
418317 float secondMin,
@@ -630,7 +529,7 @@ namespace ALPAKA_ACCELERATOR_NAMESPACE::lst {
630529 return RMSE;
631530 }
632531
633- template <typename TAcc>
532+ template <typename WP = dnn::pt3dnn::pT3WP, typename TAcc>
634533 ALPAKA_FN_ACC ALPAKA_FN_INLINE bool runPixelTripletDefaultAlgo (TAcc const & acc,
635534 ModulesConst modules,
636535 ObjectRangesConst ranges,
@@ -771,30 +670,25 @@ namespace ALPAKA_ACCELERATOR_NAMESPACE::lst {
771670
772671 rPhiChiSquared =
773672 computePT3RPhiChiSquared (acc, modules, lowerModuleIndices, pixelG, pixelF, pixelRadiusPCA, xs, ys);
774- if (runChiSquaredCuts && pixelSegmentPt < 5 .0f ) {
775- if (!passPT3RPhiChiSquaredCuts (modules, lowerModuleIndex, middleModuleIndex, upperModuleIndex, rPhiChiSquared))
776- return false ;
777- }
778673
779674 rPhiChiSquaredInwards = computePT3RPhiChiSquaredInwards (g, f, tripletRadius, xPix, yPix);
780- if (runChiSquaredCuts && pixelSegmentPt < 5 .0f ) {
781- if (!passPT3RPhiChiSquaredInwardsCuts (
782- modules, lowerModuleIndex, middleModuleIndex, upperModuleIndex, rPhiChiSquaredInwards))
783- return false ;
784- }
785675 }
786676
787677 centerX = 0 ;
788678 centerY = 0 ;
789679
790- if (runDNN and !lst::pt3dnn::runInference (acc,
791- rPhiChiSquared,
792- tripletRadius,
793- pixelRadius,
794- pixelRadiusError,
795- rzChiSquared,
796- pixelSeeds.eta ()[pixelSegmentArrayIndex],
797- pixelSegmentPt)) {
680+ // Module type of last anchor hit for the T3.
681+ const int module_type_3 = modules.moduleType ()[upperModuleIndex];
682+
683+ if (runDNN and !lst::pt3dnn::runInference<WP>(acc,
684+ rPhiChiSquared,
685+ tripletRadius,
686+ pixelRadius,
687+ pixelRadiusError,
688+ rzChiSquared,
689+ pixelSeeds.eta ()[pixelSegmentArrayIndex],
690+ pixelSegmentPt,
691+ module_type_3)) {
798692 return false ;
799693 }
800694
0 commit comments