3434
3535#include < chrono>
3636#include " include/fit_tsne.h"
37+ #include " verify.hpp"
3738
3839// #ifndef DEBUG_TIME
3940// #define DEBUG_TIME
6263#define PRINT_IL_TIMER (x ) std::cout << #x << " : " << ((float )x.count()) / 1000000.0 << " s" << std::endl
6364#endif
6465
65- double tsnecuda::RunTsne (tsnecuda::Options& opt)
66+ double tsnecuda::RunTsne (tsnecuda::Options& opt, int & success )
6667{
6768 std::chrono::steady_clock::time_point time_start_;
6869 std::chrono::steady_clock::time_point time_end_;
@@ -406,8 +407,9 @@ double tsnecuda::RunTsne(tsnecuda::Options& opt)
406407 std::cout << " done." << std::endl;
407408 }
408409
409- // int fft_dimensions[2] = {n_fft_coeffs, n_fft_coeffs}; // {780, 780}
410- // size_t work_size, work_size_dft, work_size_idft;
410+ int fft_dimensions[2 ] = {n_fft_coeffs, n_fft_coeffs}; // {780, 780}
411+ size_t work_size_idft, work_size_dft;
412+ // size_t work_size;
411413
412414 // std::cout << "Setting up dft plans...\n";
413415 // // *** TIMED SEPARATELY. NOT ADDED TO PERF TIME ***
@@ -424,41 +426,41 @@ double tsnecuda::RunTsne(tsnecuda::Options& opt)
424426 // TIME_SINCE(time_start);
425427
426428 // TIME_START();
427- // cufftHandle plan_dft;
428- // CufftSafeCall(cufftCreate(&plan_dft));
429- // CufftSafeCall(cufftMakePlanMany(
430- // plan_dft,
431- // 2,
432- // fft_dimensions,
433- // NULL,
434- // 1,
435- // n_fft_coeffs * n_fft_coeffs,
436- // NULL,
437- // 1,
438- // n_fft_coeffs * (n_fft_coeffs / 2 + 1),
439- // CUFFT_R2C,
440- // n_terms,
441- // &work_size_dft)
442- // );
429+ cufftHandle plan_dft;
430+ CufftSafeCall (cufftCreate (&plan_dft));
431+ CufftSafeCall (cufftMakePlanMany (
432+ plan_dft,
433+ 2 ,
434+ fft_dimensions,
435+ NULL ,
436+ 1 ,
437+ n_fft_coeffs * n_fft_coeffs,
438+ NULL ,
439+ 1 ,
440+ n_fft_coeffs * (n_fft_coeffs / 2 + 1 ),
441+ CUFFT_R2C,
442+ n_terms,
443+ &work_size_dft)
444+ );
443445 // TIME_SINCE(time_start);
444446
445447 // TIME_START();
446- // cufftHandle plan_idft;
447- // CufftSafeCall(cufftCreate(&plan_idft));
448- // CufftSafeCall(cufftMakePlanMany(
449- // plan_idft,
450- // 2,
451- // fft_dimensions,
452- // NULL,
453- // 1,
454- // n_fft_coeffs * (n_fft_coeffs / 2 + 1),
455- // NULL,
456- // 1,
457- // n_fft_coeffs * n_fft_coeffs,
458- // CUFFT_C2R,
459- // n_terms,
460- // &work_size_idft)
461- // );
448+ cufftHandle plan_idft;
449+ CufftSafeCall (cufftCreate (&plan_idft));
450+ CufftSafeCall (cufftMakePlanMany (
451+ plan_idft,
452+ 2 ,
453+ fft_dimensions,
454+ NULL ,
455+ 1 ,
456+ n_fft_coeffs * (n_fft_coeffs / 2 + 1 ),
457+ NULL ,
458+ 1 ,
459+ n_fft_coeffs * n_fft_coeffs,
460+ CUFFT_C2R,
461+ n_terms,
462+ &work_size_idft)
463+ );
462464 // TIME_SINCE(time_start);
463465 // std::cout << "done.\n";
464466
@@ -545,8 +547,8 @@ double tsnecuda::RunTsne(tsnecuda::Options& opt)
545547#endif
546548
547549 tsnecuda::NbodyFFT2D (
548- // plan_dft,
549- // plan_idft,
550+ plan_dft,
551+ plan_idft,
550552 fft_kernel_tilde_device, // input
551553 fft_w_coefficients, // intermediate value
552554 N,
@@ -697,6 +699,9 @@ double tsnecuda::RunTsne(tsnecuda::Options& opt)
697699 dump_file << host_ys[i] << " " << host_ys[i + num_points] << std::endl;
698700 }
699701 dump_file.close ();
702+
703+ std::string golden_file = " ../../data/tsne_mnist_output_golden.txt" ;
704+ success = verify (golden_file, opt.get_dump_file (), 0.2 , 10.0 );
700705 TIMER_END_ ()
701706
702707 host_ys.clear ();
0 commit comments