55#include < stdio.h>
66#include < omp.h>
77#include < chrono>
8+
9+ #ifndef SIMDE_ENABLE_NATIVE_ALIASES
10+ #define SIMDE_ENABLE_NATIVE_ALIASES
11+ #include " simde/x86/avx512.h" // SSE intrinsics
12+ #endif
13+
814#include " CvxCompress.hxx"
915#include " Wavelet_Transform_Fast.hxx"
1016#include " Wavelet_Transform_Slow.hxx" // for comparison in module test
@@ -195,16 +201,55 @@ float CvxCompress::Compress(
195201 long & compressed_length
196202 )
197203{
198- assert (bx >= CvxCompress::Min_BX () && bx <= CvxCompress::Max_BX () && is_pow2 (bx));
199- assert (by >= CvxCompress::Min_BY () && by <= CvxCompress::Max_BY () && is_pow2 (by));
200- assert (bz == 1 || (bz >= CvxCompress::Min_BZ () && bz <= CvxCompress::Max_BZ () && is_pow2 (bz)));
201- float global_rms = use_local_RMS ? 1 .0f : Compute_Global_RMS (vol,nx,ny,nz);
202-
203204 int num_threads;
204205#pragma omp parallel
205206 {
206207 num_threads = omp_get_num_threads ();
207208 }
209+ return Compress (scale,vol,nx,ny,nz,bx,by,bz,use_local_RMS,compressed,num_threads,compressed_length);
210+ }
211+
212+ float CvxCompress::Compress (
213+ float scale,
214+ float * vol,
215+ int nx,
216+ int ny,
217+ int nz,
218+ int bx,
219+ int by,
220+ int bz,
221+ unsigned int * compressed,
222+ int num_threads,
223+ long & compressed_length
224+ )
225+ {
226+ bool use_local_RMS = false ;
227+ return Compress (scale,vol,nx,ny,nz,bx,by,bz,use_local_RMS,compressed,num_threads,compressed_length);
228+ }
229+
230+
231+ float CvxCompress::Compress (
232+ float scale,
233+ float * vol,
234+ int nx,
235+ int ny,
236+ int nz,
237+ int bx,
238+ int by,
239+ int bz,
240+ bool use_local_RMS,
241+ unsigned int * compressed,
242+ int num_threads,
243+ long & compressed_length
244+ )
245+ {
246+ assert (bx >= CvxCompress::Min_BX () && bx <= CvxCompress::Max_BX () && is_pow2 (bx));
247+ assert (by >= CvxCompress::Min_BY () && by <= CvxCompress::Max_BY () && is_pow2 (by));
248+ assert (bz == 1 || (bz >= CvxCompress::Min_BZ () && bz <= CvxCompress::Max_BZ () && is_pow2 (bz)));
249+ float global_rms = use_local_RMS ? 1 .0f : Compute_Global_RMS (vol,nx,ny,nz);
250+
251+ omp_set_num_threads (num_threads);
252+
208253#define MAX (a,b ) (a>b?a:b)
209254 int max_bs = MAX (bx,MAX (by,bz));
210255#undef MAX
@@ -245,6 +290,7 @@ float CvxCompress::Compress(
245290
246291 float glob_mulfac = global_rms != 0 .0f ? 1 .0f / (global_rms * scale) : 1 .0f ;
247292 compressed[6 ] = *((unsigned int *)&glob_mulfac);
293+ // printf("nx=%d, ny=%d, nz=%d, bx=%d, by=%d, bz=%d, mulfac=%e\n",nx,ny,nz,bx,by,bz,glob_mulfac);
248294
249295 // flags:
250296 // 1 -> use local RMS (global RMS otherwise)
@@ -402,10 +448,38 @@ void CvxCompress::Decompress(
402448 unsigned int * compressed,
403449 long compressed_length
404450 )
451+ {
452+ int num_threads;
453+ #pragma omp parallel
454+ {
455+ num_threads = omp_get_num_threads ();
456+ }
457+ return Decompress (vol, nx, ny, nz, compressed, num_threads, compressed_length);
458+ }
459+
460+ void CvxCompress::Decompress (
461+ float *vol,
462+ int nx,
463+ int ny,
464+ int nz,
465+ unsigned int * compressed,
466+ int num_threads,
467+ long compressed_length
468+ )
405469{
406470 int nx_check = ((int *)compressed)[0 ];
407471 int ny_check = ((int *)compressed)[1 ];
408472 int nz_check = ((int *)compressed)[2 ];
473+ // Check sizes and print error message if they don't match.
474+ // for nx ny and nz
475+ if (nx != nx_check || ny != ny_check || nz != nz_check)
476+ {
477+ printf (" Error! Decompress: nx, ny, nz do not match!\n " );
478+ printf (" nx=%d, ny=%d, nz=%d, nx_check=%d, ny_check=%d, nz_check=%d\n " ,nx,ny,nz,nx_check,ny_check,nz_check);
479+ }
480+
481+ omp_set_num_threads (num_threads);
482+
409483 assert (nx == nx_check);
410484 assert (ny == ny_check);
411485 assert (nz == nz_check);
@@ -416,16 +490,16 @@ void CvxCompress::Decompress(
416490 float glob_mulfac = ((float *)compressed)[6 ];
417491 int flags = ((int *)compressed)[7 ];
418492 bool use_local_RMS = (flags & 1 ) ? true : false ;
419- // printf("nx=%d, ny=%d, nz=%d, bx=%d, by=%d, bz=%d, mulfac=%e\n",nx,ny,nz,bx,by,bz,mulfac );
493+ // printf("nx=%d, ny=%d, nz=%d, bx=%d, by=%d, bz=%d, mulfac=%e\n",nx,ny,nz,bx,by,bz,glob_mulfac );
420494
421495 int nbx = (nx+bx-1 )/bx;
422496 int nby = (ny+by-1 )/by;
423497 int nbz = (nz+bz-1 )/bz;
424498 int nnn = nbx*nby*nbz;
425- // printf("nbx=%d, nby=%d, nbz=%d, nnn=%d\n",nbx,nby,nbz,nnn);
499+ // printf("nbx=%d, nby=%d, nbz=%d, nnn=%d\n",nbx,nby,nbz,nnn);
426500
427501 long * glob_blkoffs = (long *)(compressed+8 );
428-
502+
429503 float * blkmulfac = 0L ;
430504 unsigned int * bytes;
431505 if (use_local_RMS)
@@ -439,11 +513,6 @@ void CvxCompress::Decompress(
439513 bytes = (unsigned int *)(glob_blkoffs+nnn);
440514 }
441515
442- int num_threads;
443- #pragma omp parallel
444- {
445- num_threads = omp_get_num_threads ();
446- }
447516#define MAX (a,b ) (a>b?a:b)
448517 int max_bs = MAX (bx,MAX (by,bz));
449518#undef MAX
@@ -1173,8 +1242,6 @@ bool CvxCompress::Run_Module_Tests(bool verbose, bool exhaustive_throughput_test
11731242 return forward_passed && inverse_passed && copy_to_block_passed && copy_from_block_passed && copy_round_trip_passed && global_rms_passed;
11741243}
11751244
1176- extern " C"
1177- {
11781245//
11791246float
11801247cvx_compress (
@@ -1199,7 +1266,7 @@ cvx_decompress_outofplace(
11991266 int *ny,
12001267 int *nz,
12011268 unsigned int *compressed,
1202- long compressed_length)
1269+ long compressed_length)
12031270{
12041271 CvxCompress c;
12051272 return c.Decompress (*nx, *ny, *nz, compressed, compressed_length);
@@ -1208,14 +1275,45 @@ cvx_decompress_outofplace(
12081275void
12091276cvx_decompress_inplace (
12101277 float *vol,
1211- int nx,
1212- int ny,
1213- int nz,
1278+ int nx,
1279+ int ny,
1280+ int nz,
12141281 unsigned int *compressed,
1215- long compressed_length)
1282+ long compressed_length)
12161283{
12171284 CvxCompress c;
12181285 c.Decompress (vol, nx, ny, nz, compressed, compressed_length);
12191286}
1220- //
1287+
1288+ float
1289+ cvx_compress_th (
1290+ float scale,
1291+ float *vol,
1292+ int nx,
1293+ int ny,
1294+ int nz,
1295+ int bx,
1296+ int by,
1297+ int bz,
1298+ unsigned int *compressed,
1299+ int num_threads,
1300+ long *compressed_length)
1301+ {
1302+ CvxCompress c;
1303+ return c.Compress (scale, vol, nx, ny, nz, bx, by, bz, false , compressed, num_threads, *compressed_length);
12211304}
1305+
1306+ void
1307+ cvx_decompress_inplace_th (
1308+ float *vol,
1309+ int nx,
1310+ int ny,
1311+ int nz,
1312+ unsigned int *compressed,
1313+ int num_threads,
1314+ long compressed_length)
1315+ {
1316+ CvxCompress c;
1317+ c.Decompress (vol, nx, ny, nz, compressed, num_threads, compressed_length);
1318+ }
1319+ //
0 commit comments